diff --git a/.lintrunner.toml b/.lintrunner.toml index 2835af1bf92..a26e690c3c0 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -378,3 +378,31 @@ command = [ '--', '@{{PATHSFILE}}', ] + +[[linter]] +code = "TORCH_AO_IMPORT" +include_patterns = ["**/*.py"] +exclude_patterns = [ + "third-party/**", +] + +command = [ + "python3", + "-m", + "lintrunner_adapters", + "run", + "grep_linter", + "--pattern=\\bfrom torch\\.ao\\.quantization\\.(?:quantizer|observer|quantize_pt2e|pt2e)(?:\\.[A-Za-z0-9_]+)*\\b", + "--linter-name=TorchAOImport", + "--error-name=Prohibited torch.ao.quantization import", + """--error-description=\ + Imports from torch.ao.quantization are not allowed. \ + Please import from torchao.quantization.pt2e instead.\n \ + * torchao.quantization.pt2e (includes all the utils, including observers, fake quants etc.) \n \ + * torchao.quantization.pt2e.quantizer (quantizer related objects and utils) \n \ + * torchao.quantization.pt2e.quantize_pt2e (prepare_pt2e, prepare_qat_pt2e, convert_pt2e) \n\n \ + If you need something from torch.ao.quantization, you can add your file to an exclude_patterns for TORCH_AO_IMPORT in .lintrunner.toml. 
\ + """, + "--", + "@{{PATHSFILE}}", +] diff --git a/.mypy.ini b/.mypy.ini index 5ee07ddb2bf..cd14cbac7ea 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -97,3 +97,6 @@ ignore_missing_imports = True [mypy-zstd] ignore_missing_imports = True + +[mypy-torchao.*] +follow_untyped_imports = True diff --git a/backends/apple/coreml/test/test_coreml_quantizer.py b/backends/apple/coreml/test/test_coreml_quantizer.py index db75631dbc8..d5754328796 100644 --- a/backends/apple/coreml/test/test_coreml_quantizer.py +++ b/backends/apple/coreml/test/test_coreml_quantizer.py @@ -15,12 +15,12 @@ ) from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer -from torch.ao.quantization.quantize_pt2e import ( +from torch.export import export_for_training +from torchao.quantization.pt2e.quantize_pt2e import ( convert_pt2e, prepare_pt2e, prepare_qat_pt2e, ) -from torch.export import export_for_training class TestCoreMLQuantizer: diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 8b052116ed8..0d4c7d91ae8 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -30,25 +30,24 @@ is_vgf, ) # usort: skip from executorch.exir.backend.compile_spec_schema import CompileSpec -from torch.ao.quantization.fake_quantize import ( +from torch.fx import GraphModule, Node +from torchao.quantization.pt2e import ( FakeQuantize, FusedMovingAvgObsFakeQuantize, -) -from torch.ao.quantization.observer import ( HistogramObserver, MinMaxObserver, MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver, + ObserverOrFakeQuantizeConstructor, PerChannelMinMaxObserver, PlaceholderObserver, ) -from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor -from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer -from torch.ao.quantization.quantizer.utils import ( - _annotate_input_qspec_map, - _annotate_output_qspec, +from torchao.quantization.pt2e.quantizer import ( + 
annotate_input_qspec_map, + annotate_output_qspec, + QuantizationSpec, + Quantizer, ) -from torch.fx import GraphModule, Node __all__ = [ "TOSAQuantizer", @@ -97,7 +96,7 @@ def get_symmetric_quantization_config( weight_qscheme = ( torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric ) - weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( + weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = ( MinMaxObserver ) if is_qat: @@ -337,14 +336,14 @@ def _annotate_io( if is_annotated(node): continue if node.op == "placeholder" and len(node.users) > 0: - _annotate_output_qspec( + annotate_output_qspec( node, quantization_config.get_output_act_qspec(), ) mark_node_as_annotated(node) if node.op == "output": parent = node.all_input_nodes[0] - _annotate_input_qspec_map( + annotate_input_qspec_map( node, parent, quantization_config.get_input_act_qspec() ) mark_node_as_annotated(node) diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index 0ce11b620a6..d6eb72f1148 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -15,10 +15,10 @@ import torch from torch._subclasses import FakeTensor - -from torch.ao.quantization.quantizer import QuantizationAnnotation from torch.fx import GraphModule, Node +from torchao.quantization.pt2e.quantizer import QuantizationAnnotation + def is_annotated(node: Node) -> bool: """Given a node return whether the node is annotated.""" diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 55c2ca21e1b..8d04227e1a7 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -12,12 +12,13 @@ import torch.fx from executorch.backends.arm.quantizer import QuantizationConfig from executorch.backends.arm.tosa_utils import get_node_debug_info -from 
torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec -from torch.ao.quantization.quantizer.utils import ( - _annotate_input_qspec_map, - _annotate_output_qspec, -) from torch.fx import Node +from torchao.quantization.pt2e.quantizer import ( + annotate_input_qspec_map, + annotate_output_qspec, + QuantizationSpecBase, + SharedQuantizationSpec, +) from .arm_quantizer_utils import ( is_annotated, @@ -118,7 +119,7 @@ def _annotate_input(node: Node, quant_property: _QuantProperty): strict=True, ): assert isinstance(n_arg, Node) - _annotate_input_qspec_map(node, n_arg, qspec) + annotate_input_qspec_map(node, n_arg, qspec) if quant_property.mark_annotated: mark_node_as_annotated(n_arg) # type: ignore[attr-defined] @@ -129,7 +130,7 @@ def _annotate_output(node: Node, quant_property: _QuantProperty): assert not quant_property.optional assert quant_property.index == 0, "Only one output annotation supported currently" - _annotate_output_qspec(node, quant_property.qspec) + annotate_output_qspec(node, quant_property.qspec) def _match_pattern( diff --git a/backends/arm/quantizer/quantization_config.py b/backends/arm/quantizer/quantization_config.py index 65435ac7c63..54698d058c4 100644 --- a/backends/arm/quantizer/quantization_config.py +++ b/backends/arm/quantizer/quantization_config.py @@ -9,9 +9,9 @@ from dataclasses import dataclass import torch -from torch.ao.quantization import ObserverOrFakeQuantize +from torchao.quantization.pt2e import ObserverOrFakeQuantize -from torch.ao.quantization.quantizer import ( +from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, FixedQParamsQuantizationSpec, QuantizationSpec, diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index 76d4950be6d..e3efec02253 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -19,8 +19,8 @@ ) from executorch.backends.arm.tosa_specification import TosaSpecification from 
executorch.backends.xnnpack.test.tester import Quantize -from torch.ao.quantization.observer import HistogramObserver -from torch.ao.quantization.quantizer import QuantizationSpec +from torchao.quantization.pt2e.observer import HistogramObserver +from torchao.quantization.pt2e.quantizer import QuantizationSpec aten_op = "torch.ops.aten.add.Tensor" exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor" diff --git a/backends/arm/test/ops/test_sigmoid_16bit.py b/backends/arm/test/ops/test_sigmoid_16bit.py index ff0fe9cc4c1..5d32bdaf279 100644 --- a/backends/arm/test/ops/test_sigmoid_16bit.py +++ b/backends/arm/test/ops/test_sigmoid_16bit.py @@ -18,8 +18,8 @@ ) from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.xnnpack.test.tester import Quantize -from torch.ao.quantization.observer import HistogramObserver -from torch.ao.quantization.quantizer import QuantizationSpec +from torchao.quantization.pt2e.observer import HistogramObserver +from torchao.quantization.pt2e.quantizer import QuantizationSpec def _get_16_bit_quant_config(): diff --git a/backends/arm/test/ops/test_sigmoid_32bit.py b/backends/arm/test/ops/test_sigmoid_32bit.py index 4edfdd6c23e..ad44fb52f6d 100644 --- a/backends/arm/test/ops/test_sigmoid_32bit.py +++ b/backends/arm/test/ops/test_sigmoid_32bit.py @@ -14,8 +14,8 @@ ) from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.xnnpack.test.tester import Quantize -from torch.ao.quantization.observer import HistogramObserver -from torch.ao.quantization.quantizer import QuantizationSpec +from torchao.quantization.pt2e.observer import HistogramObserver +from torchao.quantization.pt2e.quantizer import QuantizationSpec def _get_16_bit_quant_config(): diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index 594c4189b3a..438f07ba15f 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -37,9 +37,9 @@ from 
executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass from executorch.exir.program._program import to_edge_with_preserved_ops from torch._inductor.decomposition import remove_decompositions -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export.exported_program import ExportedProgram +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from .passes import get_cadence_passes diff --git a/backends/cadence/aot/quantizer/TARGETS b/backends/cadence/aot/quantizer/TARGETS index 75eab631dd4..f2d6e5572b7 100644 --- a/backends/cadence/aot/quantizer/TARGETS +++ b/backends/cadence/aot/quantizer/TARGETS @@ -21,6 +21,7 @@ python_library( deps = [ ":utils", "//caffe2:torch", + "//pytorch/ao:torchao", ], ) diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py index 66f6772d942..cd6a7287793 100644 --- a/backends/cadence/aot/quantizer/patterns.py +++ b/backends/cadence/aot/quantizer/patterns.py @@ -15,7 +15,7 @@ from torch import fx from torch._ops import OpOverload -from torch.ao.quantization.quantizer import ( +from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, SharedQuantizationSpec, ) diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py index 761b2bf8d31..3fbe1bcc0fd 100644 --- a/backends/cadence/aot/quantizer/quantizer.py +++ b/backends/cadence/aot/quantizer/quantizer.py @@ -38,9 +38,9 @@ from torch import fx -from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver -from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer -from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer +from torchao.quantization.pt2e.observer import HistogramObserver, MinMaxObserver +from torchao.quantization.pt2e.quantizer import DerivedQuantizationSpec, Quantizer +from torchao.quantization.pt2e.quantizer.composable_quantizer 
import ComposableQuantizer act_qspec_asym8s = QuantizationSpec( diff --git a/backends/cadence/aot/quantizer/utils.py b/backends/cadence/aot/quantizer/utils.py index 0f9c9399780..fad5ca41e22 100644 --- a/backends/cadence/aot/quantizer/utils.py +++ b/backends/cadence/aot/quantizer/utils.py @@ -14,13 +14,13 @@ import torch from torch import fx from torch._ops import OpOverload -from torch.ao.quantization import ObserverOrFakeQuantize from torch.fx import GraphModule from torch.fx.passes.utils.source_matcher_utils import ( check_subgraphs_connected, SourcePartition, ) +from torchao.quantization.pt2e import ObserverOrFakeQuantize def quantize_tensor_multiplier( diff --git a/backends/cortex_m/test/test_replace_quant_nodes.py b/backends/cortex_m/test/test_replace_quant_nodes.py index f993b42c920..54f5142add8 100644 --- a/backends/cortex_m/test/test_replace_quant_nodes.py +++ b/backends/cortex_m/test/test_replace_quant_nodes.py @@ -16,15 +16,15 @@ ReplaceQuantNodesPass, ) from executorch.exir.dialects._ops import ops as exir_ops -from torch.ao.quantization.observer import HistogramObserver -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer.quantizer import ( +from torch.export import export, export_for_training +from torch.fx import GraphModule +from torchao.quantization.pt2e.observer import HistogramObserver +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import ( QuantizationAnnotation, QuantizationSpec, Quantizer, ) -from torch.export import export, export_for_training -from torch.fx import GraphModule @dataclass(eq=True, frozen=True) diff --git a/backends/example/example_backend_delegate_passes/permute_memory_formats_pass.py b/backends/example/example_backend_delegate_passes/permute_memory_formats_pass.py index 8e857f32376..4f78c59927a 100644 --- a/backends/example/example_backend_delegate_passes/permute_memory_formats_pass.py +++ 
b/backends/example/example_backend_delegate_passes/permute_memory_formats_pass.py @@ -11,7 +11,7 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dim_order_utils import get_dim_order from executorch.exir.pass_base import ExportPass, PassResult -from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions +from torchao.quantization.pt2e import find_sequential_partitions class PermuteMemoryFormatsPass(ExportPass): diff --git a/backends/example/example_operators/utils.py b/backends/example/example_operators/utils.py index 7dca2a3be6a..d9b3a436840 100644 --- a/backends/example/example_operators/utils.py +++ b/backends/example/example_operators/utils.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation +from torchao.quantization.pt2e.quantizer.quantizer import QuantizationAnnotation def _nodes_are_annotated(node_list): diff --git a/backends/example/example_partitioner.py b/backends/example/example_partitioner.py index 5e9102e999b..7bfe783a4f3 100644 --- a/backends/example/example_partitioner.py +++ b/backends/example/example_partitioner.py @@ -19,9 +19,9 @@ ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.graph_module import get_control_flow_submodules -from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions from torch.export import ExportedProgram from torch.fx.passes.operator_support import OperatorSupportBase +from torchao.quantization.pt2e import find_sequential_partitions @final diff --git a/backends/example/example_quantizer.py b/backends/example/example_quantizer.py index 74a0057ba4a..c5a3e179695 100644 --- a/backends/example/example_quantizer.py +++ b/backends/example/example_quantizer.py @@ -11,9 +11,12 @@ from executorch.backends.example.example_operators.ops import module_to_annotator 
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import OperatorConfig from torch import fx -from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver -from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions -from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer +from torchao.quantization.pt2e import ( + find_sequential_partitions, + HistogramObserver, + MinMaxObserver, +) +from torchao.quantization.pt2e.quantizer import QuantizationSpec, Quantizer def get_uint8_tensor_spec(observer_or_fake_quant_ctr): diff --git a/backends/example/test_example_delegate.py b/backends/example/test_example_delegate.py index a382273af07..bc6ad4d7e4c 100644 --- a/backends/example/test_example_delegate.py +++ b/backends/example/test_example_delegate.py @@ -17,10 +17,10 @@ DuplicateDequantNodePass, ) from executorch.exir.delegate import executorch_call_delegate - -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export import export +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + from torchvision.models.quantization import mobilenet_v2 diff --git a/backends/mediatek/quantizer/annotator.py b/backends/mediatek/quantizer/annotator.py index d250b774af8..efdde09be88 100644 --- a/backends/mediatek/quantizer/annotator.py +++ b/backends/mediatek/quantizer/annotator.py @@ -10,18 +10,18 @@ from torch._ops import OpOverload from torch._subclasses import FakeTensor -from torch.ao.quantization.quantizer import QuantizationAnnotation -from torch.ao.quantization.quantizer.utils import ( - _annotate_input_qspec_map, - _annotate_output_qspec, -) - from torch.export import export_for_training from torch.fx import Graph, Node from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( SubgraphMatcherWithNameNodeMap, ) +from torchao.quantization.pt2e.quantizer import ( + annotate_input_qspec_map, + annotate_output_qspec as _annotate_output_qspec, + 
QuantizationAnnotation, +) + from .qconfig import QuantizationConfig @@ -108,7 +108,7 @@ def _annotate_fused_activation_pattern( torch.ops.aten.linear.default, ]: weight_node = producer_node.args[1] - _annotate_input_qspec_map( + annotate_input_qspec_map( producer_node, weight_node, quant_config.weight, @@ -201,7 +201,7 @@ def annotate_affine_ops(node: Node, quant_config: QuantizationConfig) -> None: return weight_node = node.args[1] - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight_node, quant_config.weight, @@ -260,5 +260,5 @@ def annotate_embedding_op(node: Node, quant_config: QuantizationConfig) -> None: return wgt_node = node.args[0] - _annotate_input_qspec_map(node, wgt_node, quant_config.activation) + annotate_input_qspec_map(node, wgt_node, quant_config.activation) _mark_as_annotated([node]) diff --git a/backends/mediatek/quantizer/qconfig.py b/backends/mediatek/quantizer/qconfig.py index e16f5e936cb..d9f105cd0d0 100644 --- a/backends/mediatek/quantizer/qconfig.py +++ b/backends/mediatek/quantizer/qconfig.py @@ -10,9 +10,9 @@ import torch -from torch.ao.quantization.fake_quantize import FakeQuantize -from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver -from torch.ao.quantization.quantizer import QuantizationSpec +from torchao.quantization.pt2e.fake_quantize import FakeQuantize +from torchao.quantization.pt2e.observer import MinMaxObserver, PerChannelMinMaxObserver +from torchao.quantization.pt2e.quantizer import QuantizationSpec @unique diff --git a/backends/mediatek/quantizer/quantizer.py b/backends/mediatek/quantizer/quantizer.py index 4e78d6dff1a..f9babdec997 100644 --- a/backends/mediatek/quantizer/quantizer.py +++ b/backends/mediatek/quantizer/quantizer.py @@ -4,8 +4,8 @@ # except in compliance with the License. See the license file in the root # directory of this source tree for more details. 
-from torch.ao.quantization.quantizer import Quantizer from torch.fx import GraphModule +from torchao.quantization.pt2e.quantizer import Quantizer from .._passes.decompose_scaled_dot_product_attention import ( DecomposeScaledDotProductAttention, diff --git a/backends/nxp/quantizer/neutron_quantizer.py b/backends/nxp/quantizer/neutron_quantizer.py index eff7f513cb9..8a64170b632 100644 --- a/backends/nxp/quantizer/neutron_quantizer.py +++ b/backends/nxp/quantizer/neutron_quantizer.py @@ -35,9 +35,9 @@ QuantizationSpec, ) from torch import fx -from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver -from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer -from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer +from torchao.quantization.pt2e.observer import HistogramObserver, MinMaxObserver +from torchao.quantization.pt2e.quantizer import DerivedQuantizationSpec, Quantizer +from torchao.quantization.pt2e.quantizer.composable_quantizer import ComposableQuantizer class NeutronAtenQuantizer(Quantizer): diff --git a/backends/nxp/quantizer/patterns.py b/backends/nxp/quantizer/patterns.py index 6797447c50c..b71f0621002 100644 --- a/backends/nxp/quantizer/patterns.py +++ b/backends/nxp/quantizer/patterns.py @@ -14,7 +14,7 @@ from executorch.backends.nxp.quantizer.utils import get_bias_qparams from torch import fx from torch._ops import OpOverload -from torch.ao.quantization.quantizer import ( +from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, FixedQParamsQuantizationSpec, SharedQuantizationSpec, diff --git a/backends/nxp/quantizer/utils.py b/backends/nxp/quantizer/utils.py index 1effcdff25a..1b941f6e632 100644 --- a/backends/nxp/quantizer/utils.py +++ b/backends/nxp/quantizer/utils.py @@ -14,11 +14,11 @@ import torch from torch import fx from torch._ops import OpOverload -from torch.ao.quantization import ObserverOrFakeQuantize from torch.fx.passes.utils.source_matcher_utils 
import ( check_subgraphs_connected, SourcePartition, ) +from torchao.quantization.pt2e import ObserverOrFakeQuantize def is_annotated(nodes: List[fx.Node]) -> bool: diff --git a/backends/nxp/tests/executorch_pipeline.py b/backends/nxp/tests/executorch_pipeline.py index 6c452b99baf..fe1cdeaf751 100644 --- a/backends/nxp/tests/executorch_pipeline.py +++ b/backends/nxp/tests/executorch_pipeline.py @@ -20,7 +20,7 @@ to_edge_transform_and_lower, ) from torch import nn -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]): diff --git a/backends/nxp/tests/test_quantizer.py b/backends/nxp/tests/test_quantizer.py index 868a94059b5..dd1b691a18f 100644 --- a/backends/nxp/tests/test_quantizer.py +++ b/backends/nxp/tests/test_quantizer.py @@ -8,7 +8,7 @@ import executorch.backends.nxp.tests.models as models import torch from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e def _get_target_name(node): diff --git a/backends/openvino/quantizer/quantizer.py b/backends/openvino/quantizer/quantizer.py index 5532235f573..134ed5cb4ac 100644 --- a/backends/openvino/quantizer/quantizer.py +++ b/backends/openvino/quantizer/quantizer.py @@ -17,12 +17,12 @@ import torch.fx from nncf.common.graph.graph import NNCFGraph # type: ignore[import-untyped] -from torch.ao.quantization.observer import ( +from torchao.quantization.pt2e.observer import ( HistogramObserver, PerChannelMinMaxObserver, UniformQuantizationObserverBase, ) -from torch.ao.quantization.quantizer.quantizer import ( +from torchao.quantization.pt2e.quantizer.quantizer import ( EdgeOrNode, QuantizationAnnotation, QuantizationSpec, diff --git 
a/backends/openvino/scripts/openvino_build.sh b/backends/openvino/scripts/openvino_build.sh index 83ffd7542f3..e36501c941a 100755 --- a/backends/openvino/scripts/openvino_build.sh +++ b/backends/openvino/scripts/openvino_build.sh @@ -59,6 +59,9 @@ main() { # Build the package pip install . --no-build-isolation + # Install torchao + pip install third-party/ao + else echo "Error: Argument is not valid: $build_type" exit 1 # Exit the script with an error code diff --git a/backends/qualcomm/_passes/annotate_quant_attrs.py b/backends/qualcomm/_passes/annotate_quant_attrs.py index b4f14fc28cd..5a41f04ff7b 100644 --- a/backends/qualcomm/_passes/annotate_quant_attrs.py +++ b/backends/qualcomm/_passes/annotate_quant_attrs.py @@ -112,9 +112,9 @@ def _dequant_fold_params(self, n, quant_attrs, param): offsets = self._expand(quant_attrs[QCOM_ZERO_POINTS], dim, axis) param = param.sub(offsets).mul(scales).to(torch.float32).contiguous() elif quant_attrs[QCOM_ENCODING] in [ - exir_ops.edge.pt2e_quant.dequantize_affine.default + exir_ops.edge.torchao.dequantize_affine.default ]: - param = torch.ops.pt2e_quant.dequantize_affine( + param = torch.ops.torchao.dequantize_affine( param, block_size=quant_attrs[QCOM_BLOCK_SIZE], scale=quant_attrs[QCOM_SCALE], diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index c98f27db120..c4b730dc5b2 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -131,8 +131,8 @@ def get_to_edge_transform_passes( from executorch.backends.qualcomm._passes import utils from executorch.exir.dialects._ops import ops as exir_ops - utils.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default) - utils.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default) + utils.q_ops.add(exir_ops.edge.torchao.quantize_affine.default) + utils.dq_ops.add(exir_ops.edge.torchao.dequantize_affine.default) passes_job = ( passes_job if passes_job is not None else 
get_capture_program_passes() diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index 7965a30caea..e99d6b2e620 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -242,8 +242,8 @@ def get_quant_encoding_conf( ) # TODO: refactor this when target could be correctly detected per_block_encoding = { - exir_ops.edge.pt2e_quant.quantize_affine.default, - exir_ops.edge.pt2e_quant.dequantize_affine.default, + exir_ops.edge.torchao.quantize_affine.default, + exir_ops.edge.torchao.dequantize_affine.default, } if quant_attrs[QCOM_ENCODING] in per_block_encoding: return self.make_qnn_per_block_config(node, quant_attrs) @@ -271,7 +271,7 @@ def get_quant_tensor_value( axis_order.index(x) for x in range(len(axis_order)) ) tensor = tensor.permute(origin_order) - tensor = torch.ops.pt2e_quant.quantize_affine( + tensor = torch.ops.torchao.quantize_affine( tensor, block_size=quant_attrs[QCOM_BLOCK_SIZE], scale=quant_attrs[QCOM_SCALE], diff --git a/backends/qualcomm/partition/utils.py b/backends/qualcomm/partition/utils.py index 816d1ac1d9b..05bbd1ff970 100644 --- a/backends/qualcomm/partition/utils.py +++ b/backends/qualcomm/partition/utils.py @@ -57,7 +57,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]: torch.ops.aten.upsample_bicubic2d.vec, # This request is ignored because it is in a blocklist. 
Refer to exir/program/_program.py torch.ops.aten.unbind.int, - torch.ops.pt2e_quant.quantize_affine.default, - torch.ops.pt2e_quant.dequantize_affine.default, + torch.ops.torchao.quantize_affine.default, + torch.ops.torchao.dequantize_affine.default, ] return do_not_decompose diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index 469a801feeb..545039555cf 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -12,20 +12,20 @@ from torch._ops import OpOverload from torch._subclasses import FakeTensor -from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize +from torch.fx import Node +from torchao.quantization.pt2e.fake_quantize import FixedQParamsFakeQuantize -from torch.ao.quantization.observer import FixedQParamsObserver -from torch.ao.quantization.quantizer import ( +from torchao.quantization.pt2e.observer import FixedQParamsObserver +from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, QuantizationAnnotation, QuantizationSpec, SharedQuantizationSpec, ) -from torch.ao.quantization.quantizer.utils import ( - _annotate_input_qspec_map, - _annotate_output_qspec, +from torchao.quantization.pt2e.quantizer.utils import ( + annotate_input_qspec_map, + annotate_output_qspec, ) -from torch.fx import Node from .qconfig import ( get_16a16w_qnn_ptq_config, @@ -618,19 +618,19 @@ def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> No return # TODO current only support 16a16w - _annotate_input_qspec_map( + annotate_input_qspec_map( node, act_node, quantization_config.input_activation, ) - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight_node, quantization_config.input_activation, ) nodes_to_mark_annotated = [node] - _annotate_output_qspec(node, quantization_config.output_activation) + annotate_output_qspec(node, quantization_config.output_activation) 
_mark_nodes_as_annotated(nodes_to_mark_annotated) @@ -819,25 +819,25 @@ def annotate_group_norm(node: Node, quantization_config: QuantizationConfig) -> if _is_annotated([node]): return - _annotate_input_qspec_map( + annotate_input_qspec_map( node, act_node, quantization_config.input_activation, ) - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight_node, quantization_config.weight, ) nodes_to_mark_annotated = [node, weight_node] if bias_node: - _annotate_input_qspec_map( + annotate_input_qspec_map( node, bias_node, quantization_config.bias, ) nodes_to_mark_annotated.append(bias_node) - _annotate_output_qspec(node, quantization_config.output_activation) + annotate_output_qspec(node, quantization_config.output_activation) _mark_nodes_as_annotated(nodes_to_mark_annotated) @@ -1002,12 +1002,12 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None if _is_annotated([node]): return - _annotate_input_qspec_map( + annotate_input_qspec_map( node, act_node, quantization_config.input_activation, ) - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight_node, quantization_config.weight, @@ -1018,9 +1018,9 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None bias_config = quantization_config.bias(node) else: bias_config = quantization_config.bias - _annotate_input_qspec_map(node, bias_node, bias_config) + annotate_input_qspec_map(node, bias_node, bias_config) nodes_to_mark_annotated.append(bias_node) - _annotate_output_qspec(node, quantization_config.output_activation) + annotate_output_qspec(node, quantization_config.output_activation) _mark_nodes_as_annotated(nodes_to_mark_annotated) # We use get_source_partition in pass, but it is the same source for MultiheadAttention, so we need to change its source_fn_stack. 
@@ -1038,14 +1038,14 @@ def annotate_batch_and_instance_norm( return annotated_args = [act] - _annotate_input_qspec_map( + annotate_input_qspec_map( node, act, quantization_config.input_activation, ) # QNN requires uint8 instead of int8 in 'weight' config if weight is not None: - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight, quantization_config.input_activation, @@ -1053,14 +1053,14 @@ def annotate_batch_and_instance_norm( annotated_args.append(weight) if bias is not None: - _annotate_input_qspec_map( + annotate_input_qspec_map( node, bias, quantization_config.bias, ) annotated_args.append(bias) - _annotate_output_qspec(node, quantization_config.output_activation) + annotate_output_qspec(node, quantization_config.output_activation) _mark_nodes_as_annotated([node, *annotated_args]) @@ -1070,7 +1070,7 @@ def annotate_getitem(node: Node, quantization_config: QuantizationConfig) -> Non return if _is_float_tensor(node): - _annotate_output_qspec(node, quantization_config.output_activation) + annotate_output_qspec(node, quantization_config.output_activation) _mark_nodes_as_annotated([node]) @@ -1086,32 +1086,32 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> return input_act_qspec = quantization_config.input_activation - _annotate_input_qspec_map( + annotate_input_qspec_map( node, act_node, input_act_qspec, ) if input_act_qspec.dtype == torch.int32: - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight_node, get_16a16w_qnn_ptq_config().weight, ) else: - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight_node, input_act_qspec, ) nodes_to_mark_annotated = [node, weight_node] if bias_node: - _annotate_input_qspec_map( + annotate_input_qspec_map( node, bias_node, quantization_config.bias, ) nodes_to_mark_annotated.append(bias_node) - _annotate_output_qspec(node, quantization_config.output_activation) + annotate_output_qspec(node, quantization_config.output_activation) 
_mark_nodes_as_annotated(nodes_to_mark_annotated) diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index bda91609f1c..2771645c3e0 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -17,13 +17,13 @@ QuantizationConfig, ) from executorch.exir.dialects._ops import ops as exir_ops -from torch.ao.quantization.observer import FixedQParamsObserver, MinMaxObserver -from torch.ao.quantization.quantizer import ( +from torch.fx import Node +from torchao.quantization.pt2e.observer import FixedQParamsObserver, MinMaxObserver +from torchao.quantization.pt2e.quantizer import ( QuantizationAnnotation, QuantizationSpec, SharedQuantizationSpec, ) -from torch.fx import Node def annotate_mimi_decoder(gm: torch.fx.GraphModule): diff --git a/backends/qualcomm/quantizer/observers/per_block_param_observer.py b/backends/qualcomm/quantizer/observers/per_block_param_observer.py index e60f15c6d9c..982f3fdbb65 100644 --- a/backends/qualcomm/quantizer/observers/per_block_param_observer.py +++ b/backends/qualcomm/quantizer/observers/per_block_param_observer.py @@ -7,12 +7,12 @@ from typing import Tuple import torch -from torch.ao.quantization.observer import MappingType, PerBlock -from torch.ao.quantization.pt2e._affine_quantization import ( +from torchao.quantization.pt2e._affine_quantization import ( _get_reduction_params, AffineQuantizedMinMaxObserver, choose_qparams_affine_with_min_max, ) +from torchao.quantization.pt2e.observer import MappingType, PerBlock class PerBlockParamObserver(AffineQuantizedMinMaxObserver): diff --git a/backends/qualcomm/quantizer/observers/per_channel_param_observer.py b/backends/qualcomm/quantizer/observers/per_channel_param_observer.py index 3c04e620308..cf57c94b72e 100644 --- a/backends/qualcomm/quantizer/observers/per_channel_param_observer.py +++ b/backends/qualcomm/quantizer/observers/per_channel_param_observer.py @@ -5,7 +5,7 
@@ # LICENSE file in the root directory of this source tree. import torch -from torch.ao.quantization.observer import UniformQuantizationObserverBase +from torchao.quantization.pt2e.observer import UniformQuantizationObserverBase # TODO move to torch/ao/quantization/observer.py. diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py index 67968363eb6..e0a838cc32e 100644 --- a/backends/qualcomm/quantizer/qconfig.py +++ b/backends/qualcomm/quantizer/qconfig.py @@ -7,18 +7,21 @@ PerBlockParamObserver, ) from torch import Tensor -from torch.ao.quantization.fake_quantize import ( +from torch.fx import Node +from torchao.quantization.pt2e.fake_quantize import ( FakeQuantize, FusedMovingAvgObsFakeQuantize, ) -from torch.ao.quantization.observer import ( +from torchao.quantization.pt2e.observer import ( MinMaxObserver, MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver, PerChannelMinMaxObserver, ) -from torch.ao.quantization.quantizer import DerivedQuantizationSpec, QuantizationSpec -from torch.fx import Node +from torchao.quantization.pt2e.quantizer import ( + DerivedQuantizationSpec, + QuantizationSpec, +) @dataclass(eq=True) diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 8e65607dd84..4a1bca70add 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -9,11 +9,12 @@ from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple import torch +import torchao from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager from torch._ops import OpOverload -from torch.ao.quantization.quantizer import Quantizer from torch.fx import GraphModule +from torchao.quantization.pt2e.quantizer import Quantizer from .annotators import OP_ANNOTATOR @@ -131,7 +132,7 @@ class ModuleQConfig: is_conv_per_channel: bool = False is_linear_per_channel: bool = False act_observer: Optional[ - 
torch.ao.quantization.observer.UniformQuantizationObserverBase + torchao.quantization.pt2e.observer.UniformQuantizationObserverBase ] = None def __post_init__(self): diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index 695c846de05..bafa753f982 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -14,6 +14,8 @@ import numpy as np import torch +import torchao + from executorch import exir from executorch.backends.qualcomm.qnn_preprocess import QnnBackend from executorch.backends.qualcomm.quantizer.quantizer import ModuleQConfig, QuantDtype @@ -43,12 +45,12 @@ from executorch.exir.pass_base import ExportPass from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from executorch.exir.program import ExecutorchProgram, ExecutorchProgramManager -from torch.ao.quantization.quantize_pt2e import ( +from torch.fx.passes.infra.pass_base import PassResult +from torchao.quantization.pt2e.quantize_pt2e import ( convert_pt2e, prepare_pt2e, prepare_qat_pt2e, ) -from torch.fx.passes.infra.pass_base import PassResult def generate_context_binary( @@ -536,8 +538,8 @@ def get_qdq_module( torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.quantized_decomposed.quantize_per_channel.default, torch.ops.quantized_decomposed.dequantize_per_channel.default, - torch.ops.pt2e_quant.quantize_affine.default, - torch.ops.pt2e_quant.dequantize_affine.default, + torch.ops.torchao.quantize_affine.default, + torch.ops.torchao.dequantize_affine.default, } if not bypass_check: self.assertTrue(nodes.intersection(q_and_dq)) @@ -568,7 +570,7 @@ def get_prepared_qat_module( quantizer.set_submodule_qconfig_list(submodule_qconfig_list) prepared = prepare_qat_pt2e(m, quantizer) - return torch.ao.quantization.move_exported_model_to_train(prepared) + return torchao.quantization.pt2e.move_exported_model_to_train(prepared) def get_converted_sgd_trained_module( self, @@ -583,7 +585,7 @@ def 
get_converted_sgd_trained_module( optimizer.zero_grad() loss.backward() optimizer.step() - return torch.ao.quantization.quantize_pt2e.convert_pt2e(prepared) + return torchao.quantization.pt2e.quantize_pt2e.convert_pt2e(prepared) def split_graph(self, division: int): class SplitGraph(ExportPass): diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 3653cd3176f..c6ba6a4a972 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -607,8 +607,8 @@ def skip_annotation( from executorch.backends.qualcomm.serialization.qc_schema_serialize import ( flatbuffer_to_option, ) - from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner + from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e def prepare_subgm(subgm, subgm_name): # prepare current submodule for quantization annotation diff --git a/backends/transforms/duplicate_dynamic_quant_chain.py b/backends/transforms/duplicate_dynamic_quant_chain.py index 2ca65eec45f..d7f119c4cf6 100644 --- a/backends/transforms/duplicate_dynamic_quant_chain.py +++ b/backends/transforms/duplicate_dynamic_quant_chain.py @@ -9,14 +9,11 @@ import torch -from torch.ao.quantization.pt2e.utils import ( - _filter_sym_size_users, - _is_valid_annotation, -) - from torch.fx.node import map_arg from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torchao.quantization.pt2e.quantizer import is_valid_annotation +from torchao.quantization.pt2e.utils import _filter_sym_size_users logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) @@ -129,7 +126,7 @@ def _maybe_duplicate_dynamic_quantize_chain( dq_node_users = list(dq_node.users.copy()) for user in dq_node_users: annotation = user.meta.get("quantization_annotation", None) - if not _is_valid_annotation(annotation): + if not is_valid_annotation(annotation): return with 
gm.graph.inserting_after(dq_node): new_node = gm.graph.node_copy(dq_node) diff --git a/backends/transforms/test/test_duplicate_dynamic_quant_chain.py b/backends/transforms/test/test_duplicate_dynamic_quant_chain.py index ab965dd347d..79bc56f8780 100644 --- a/backends/transforms/test/test_duplicate_dynamic_quant_chain.py +++ b/backends/transforms/test/test_duplicate_dynamic_quant_chain.py @@ -15,7 +15,6 @@ get_symmetric_quantization_config, XNNPACKQuantizer, ) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e # TODO: Move away from using torch's internal testing utils from torch.testing._internal.common_quantization import ( @@ -23,6 +22,7 @@ QuantizationTestCase, TestHelperModules, ) +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e class MyTestHelperModules: diff --git a/backends/vulkan/quantizer/vulkan_quantizer.py b/backends/vulkan/quantizer/vulkan_quantizer.py index b2f1a658040..736649a016b 100644 --- a/backends/vulkan/quantizer/vulkan_quantizer.py +++ b/backends/vulkan/quantizer/vulkan_quantizer.py @@ -18,9 +18,9 @@ propagate_annotation, QuantizationConfig, ) -from torch.ao.quantization.observer import PerChannelMinMaxObserver -from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer from torch.fx import Node +from torchao.quantization.pt2e.observer import PerChannelMinMaxObserver +from torchao.quantization.pt2e.quantizer import QuantizationSpec, Quantizer __all__ = [ diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index b57710974e8..41866fe4a46 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -23,11 +23,11 @@ EdgeProgramManager, ExecutorchProgramManager, ) +from torch.export import Dim, export, export_for_training, ExportedProgram -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, 
prepare_pt2e -from torch.ao.quantization.quantizer import Quantizer -from torch.export import Dim, export, export_for_training, ExportedProgram +from torchao.quantization.pt2e.quantizer import Quantizer ctypes.CDLL("libvulkan.so.1") diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py index 7572ebd5a5a..ff9e2d85a96 100644 --- a/backends/vulkan/test/test_vulkan_passes.py +++ b/backends/vulkan/test/test_vulkan_passes.py @@ -17,8 +17,8 @@ format_target_name, ) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer import Quantizer +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import Quantizer ################### ## Common Models ## diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer.py b/backends/xnnpack/quantizer/xnnpack_quantizer.py index 229b75f0ed9..66f1fb14750 100644 --- a/backends/xnnpack/quantizer/xnnpack_quantizer.py +++ b/backends/xnnpack/quantizer/xnnpack_quantizer.py @@ -17,11 +17,11 @@ propagate_annotation, QuantizationConfig, ) -from torch.ao.quantization.fake_quantize import ( +from torchao.quantization.pt2e.fake_quantize import ( FakeQuantize, FusedMovingAvgObsFakeQuantize, ) -from torch.ao.quantization.observer import ( +from torchao.quantization.pt2e.observer import ( HistogramObserver, MinMaxObserver, MovingAverageMinMaxObserver, @@ -29,13 +29,13 @@ PerChannelMinMaxObserver, PlaceholderObserver, ) -from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer -from torch.ao.quantization.quantizer.utils import _get_module_name_filter +from torchao.quantization.pt2e.quantizer import QuantizationSpec, Quantizer +from torchao.quantization.pt2e.quantizer.utils import get_module_name_filter if TYPE_CHECKING: - from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor from torch.fx import Node + from torchao.quantization.pt2e import 
ObserverOrFakeQuantizeConstructor __all__ = [ @@ -140,7 +140,7 @@ def get_symmetric_quantization_config( weight_qscheme = ( torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric ) - weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( + weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = ( MinMaxObserver ) if is_qat: @@ -228,7 +228,7 @@ def _get_not_module_type_or_name_filter( tp_list: list[Callable], module_name_list: list[str] ) -> Callable[[Node], bool]: module_type_filters = [_get_module_type_filter(tp) for tp in tp_list] - module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list] + module_name_list_filters = [get_module_name_filter(m) for m in module_name_list] def not_module_type_or_name_filter(n: Node) -> bool: return not any(f(n) for f in module_type_filters + module_name_list_filters) @@ -421,7 +421,7 @@ def _annotate_for_quantization_config( module_name_list = list(self.module_name_config.keys()) for module_name, config in self.module_name_config.items(): self._annotate_all_patterns( - model, config, _get_module_name_filter(module_name) + model, config, get_module_name_filter(module_name) ) tp_list = list(self.module_type_config.keys()) diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py index 4b961bef81d..ff4114cfdd4 100644 --- a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py +++ b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py @@ -9,26 +9,26 @@ from executorch.backends.xnnpack.utils.utils import is_depthwise_conv from torch._subclasses import FakeTensor from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix -from torch.ao.quantization.pt2e.export_utils import _WrapperModule -from torch.ao.quantization.pt2e.utils import ( - _get_aten_graph_module_for_pattern, - _is_conv_node, - _is_conv_transpose_node, +from torch.fx import Node +from 
torch.fx.passes.utils.matcher_with_name_node_map_utils import ( + SubgraphMatcherWithNameNodeMap, ) -from torch.ao.quantization.quantizer import ( +from torch.fx.passes.utils.source_matcher_utils import get_source_partitions +from torchao.quantization.pt2e.export_utils import WrapperModule +from torchao.quantization.pt2e.quantizer import ( QuantizationAnnotation, QuantizationSpec, SharedQuantizationSpec, ) -from torch.ao.quantization.quantizer.utils import ( - _annotate_input_qspec_map, - _annotate_output_qspec, +from torchao.quantization.pt2e.quantizer.utils import ( + annotate_input_qspec_map, + annotate_output_qspec, ) -from torch.fx import Node -from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( - SubgraphMatcherWithNameNodeMap, +from torchao.quantization.pt2e.utils import ( + _get_aten_graph_module_for_pattern, + _is_conv_node, + _is_conv_transpose_node, ) -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions __all__ = [ "OperatorConfig", @@ -204,25 +204,25 @@ def _annotate_linear( bias_node = node.args[2] if _is_annotated([node]) is False: # type: ignore[list-item] - _annotate_input_qspec_map( + annotate_input_qspec_map( node, act_node, input_act_qspec, ) - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight_node, weight_qspec, ) nodes_to_mark_annotated = [node, weight_node] if bias_node: - _annotate_input_qspec_map( + annotate_input_qspec_map( node, bias_node, bias_qspec, ) nodes_to_mark_annotated.append(bias_node) - _annotate_output_qspec(node, output_act_qspec) + annotate_output_qspec(node, output_act_qspec) _mark_nodes_as_annotated(nodes_to_mark_annotated) annotated_partitions.append(nodes_to_mark_annotated) @@ -572,7 +572,7 @@ def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv): "output": output, } - return _WrapperModule(_conv_bn) + return WrapperModule(_conv_bn) # Needed for matching, otherwise the matches gets filtered out due to unused # nodes returned by batch norm diff 
--git a/backends/xnnpack/test/ops/test_check_quant_params.py b/backends/xnnpack/test/ops/test_check_quant_params.py index d05b1fce540..8be59aab50e 100644 --- a/backends/xnnpack/test/ops/test_check_quant_params.py +++ b/backends/xnnpack/test/ops/test_check_quant_params.py @@ -9,8 +9,8 @@ ) from executorch.backends.xnnpack.utils.utils import get_param_tensor from executorch.exir import to_edge_transform_and_lower -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export import export_for_training +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e class TestCheckQuantParams(unittest.TestCase): diff --git a/backends/xnnpack/test/quantizer/test_pt2e_quantization.py b/backends/xnnpack/test/quantizer/test_pt2e_quantization.py index 4243441118e..2f1d81e95b5 100644 --- a/backends/xnnpack/test/quantizer/test_pt2e_quantization.py +++ b/backends/xnnpack/test/quantizer/test_pt2e_quantization.py @@ -10,21 +10,12 @@ from typing import Dict, Tuple import torch +import torchao from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, ) -from torch.ao.quantization import ( - compare_results, - CUSTOM_KEY, - default_per_channel_symmetric_qnnpack_qconfig, - extract_results_from_loggers, - generate_numeric_debug_handle, - NUMERIC_DEBUG_HANDLE_KEY, - observer, - prepare_for_propagation_comparison, -) -from torch.ao.quantization.pt2e.graph_utils import bfs_trace_with_node_process +from torch.ao.quantization import default_per_channel_symmetric_qnnpack_qconfig from torch.ao.quantization.qconfig import ( float_qparams_weight_only_qconfig, per_channel_weight_observer_range_neg_127_to_127, @@ -32,18 +23,9 @@ weight_observer_range_neg_127_to_127, ) from torch.ao.quantization.qconfig_mapping import QConfigMapping -from torch.ao.quantization.quantize_pt2e import ( - convert_pt2e, - prepare_pt2e, - prepare_qat_pt2e, -) -from torch.ao.quantization.quantizer import 
Quantizer -from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer -from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, - PT2EQuantizationTestCase, TestHelperModules, ) from torch.testing._internal.common_utils import ( @@ -51,9 +33,42 @@ TemporaryFileName, TestCase, ) +from torchao.quantization.pt2e import ( + compare_results, + CUSTOM_KEY, + extract_results_from_loggers, + generate_numeric_debug_handle, + NUMERIC_DEBUG_HANDLE_KEY, + observer, + prepare_for_propagation_comparison, +) +from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process +from torchao.quantization.pt2e.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) +from torchao.quantization.pt2e.quantizer import Quantizer +from torchao.quantization.pt2e.quantizer.composable_quantizer import ComposableQuantizer +from torchao.quantization.pt2e.quantizer.embedding_quantizer import EmbeddingQuantizer +from torchao.testing.pt2e.utils import PT2EQuantizationTestCase class TestQuantizePT2E(PT2EQuantizationTestCase): + def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False): + # resetting dynamo cache + torch._dynamo.reset() + + gm = export_for_training(m, example_inputs, strict=True).module() + assert isinstance(gm, torch.fx.GraphModule) + if is_qat: + gm = prepare_qat_pt2e(gm, quantizer) + else: + gm = prepare_pt2e(gm, quantizer) + gm(*example_inputs) + gm = convert_pt2e(gm) + return gm + def _get_pt2e_quantized_linear( self, is_per_channel: bool = False ) -> torch.fx.GraphModule: @@ -287,7 +302,7 @@ def test_embedding_conv_linear_quantization(self) -> None: [embedding_quantizer, dynamic_quantizer, static_quantizer] ) - act_affine_quant_obs = observer.PlaceholderObserver.with_args( + act_affine_quant_obs = torch.ao.quantization.PlaceholderObserver.with_args( 
dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=-128, @@ -404,7 +419,7 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None: m.train() # After wrapping: does not error and swaps the ops accordingly - torch.ao.quantization.allow_exported_model_train_eval(m) # pyre-ignore[6] + torchao.quantization.pt2e.allow_exported_model_train_eval(m) # pyre-ignore[6] m.eval() _assert_ops_are_correct(m, train=False) # pyre-ignore[6] m.train() @@ -419,7 +434,7 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None: m.train() # After prepare and after wrapping: does not error and swaps the ops accordingly - torch.ao.quantization.allow_exported_model_train_eval(m) + torchao.quantization.pt2e.allow_exported_model_train_eval(m) m.eval() _assert_ops_are_correct(m, train=False) m.train() @@ -433,7 +448,7 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None: m.train() # After convert and after wrapping: does not error and swaps the ops accordingly - torch.ao.quantization.allow_exported_model_train_eval(m) + torchao.quantization.pt2e.allow_exported_model_train_eval(m) m.eval() _assert_ops_are_correct(m, train=False) m.train() @@ -783,7 +798,7 @@ def test_extract_results_from_loggers(self) -> None: ref_results = extract_results_from_loggers(m_ref_logger) quant_results = extract_results_from_loggers(m_quant_logger) comparison_results = compare_results( - ref_results, + ref_results, # pyre-ignore[6] quant_results, # pyre-ignore[6] ) for node_summary in comparison_results.values(): diff --git a/backends/xnnpack/test/quantizer/test_representation.py b/backends/xnnpack/test/quantizer/test_representation.py index e52bbbd7ae7..817f7f9e368 100644 --- a/backends/xnnpack/test/quantizer/test_representation.py +++ b/backends/xnnpack/test/quantizer/test_representation.py @@ -8,8 +8,6 @@ XNNPACKQuantizer, ) from torch._higher_order_ops.out_dtype import out_dtype # noqa: F401 -from torch.ao.quantization.quantize_pt2e import 
convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer import Quantizer from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, @@ -17,6 +15,8 @@ skipIfNoQNNPACK, TestHelperModules, ) +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import Quantizer @skipIfNoQNNPACK diff --git a/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py b/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py index 9b0515bad58..0a317ad8822 100644 --- a/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py +++ b/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py @@ -12,7 +12,6 @@ from torch.ao.quantization import ( default_dynamic_fake_quant, default_dynamic_qconfig, - observer, QConfig, QConfigMapping, ) @@ -28,16 +27,16 @@ convert_to_reference_fx, prepare_fx, ) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, - PT2EQuantizationTestCase, skip_if_no_torchvision, skipIfNoQNNPACK, TestHelperModules, ) from torch.testing._internal.common_quantized import override_quantized_engine +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.testing.pt2e.utils import PT2EQuantizationTestCase @skipIfNoQNNPACK @@ -575,7 +574,7 @@ def test_dynamic_linear(self): torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } - act_affine_quant_obs = observer.PlaceholderObserver.with_args( + act_affine_quant_obs = torch.ao.quantization.PlaceholderObserver.with_args( dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=-128, @@ -621,7 +620,7 @@ def test_dynamic_linear_int4_weight(self): torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } - act_affine_quant_obs = observer.PlaceholderObserver.with_args( + act_affine_quant_obs = torch.ao.quantization.PlaceholderObserver.with_args( dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=-128, @@ -718,7 +717,7 @@ def test_dynamic_linear_with_conv(self): torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, } - act_affine_quant_obs = observer.PlaceholderObserver.with_args( + act_affine_quant_obs = torch.ao.quantization.PlaceholderObserver.with_args( dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=-128, diff --git a/backends/xnnpack/test/test_xnnpack_utils.py b/backends/xnnpack/test/test_xnnpack_utils.py index 3ff2f0e4c1e..e6c97545d82 100644 --- a/backends/xnnpack/test/test_xnnpack_utils.py +++ b/backends/xnnpack/test/test_xnnpack_utils.py @@ -47,7 +47,6 @@ from torch.ao.quantization import ( # @manual default_per_channel_symmetric_qnnpack_qconfig, - PlaceholderObserver, QConfig, QConfigMapping, ) @@ -55,12 +54,6 @@ from torch.ao.quantization.backend_config.executorch import ( get_executorch_backend_config, ) - -from torch.ao.quantization.observer import ( - per_channel_weight_observer_range_neg_127_to_127, - # default_weight_observer, - weight_observer_range_neg_127_to_127, -) from torch.ao.quantization.qconfig_mapping import ( _get_default_qconfig_mapping_with_default_qconfig, _get_symmetric_qnnpack_qconfig_mapping, @@ -70,11 +63,18 @@ _convert_to_reference_decomposed_fx, prepare_fx, ) - -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export import export_for_training from torch.testing import FileCheck +from torchao.quantization.pt2e import PlaceholderObserver + +from torchao.quantization.pt2e.observer import ( + per_channel_weight_observer_range_neg_127_to_127, + # default_weight_observer, + weight_observer_range_neg_127_to_127, +) + +from 
torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e def randomize_bn(num_features: int, dimensionality: int = 2) -> torch.nn.Module: diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index cbce817cf4b..fd48837bd72 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -55,11 +55,11 @@ ) from executorch.exir.program._program import _transform from torch._export.pass_base import PassType -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer.quantizer import Quantizer from torch.export import export, ExportedProgram from torch.testing import FileCheck from torch.utils._pytree import tree_flatten +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer.quantizer import Quantizer class Stage(ABC): diff --git a/docs/source/backends-coreml.md b/docs/source/backends-coreml.md index 29a4b331be6..1292c80148d 100644 --- a/docs/source/backends-coreml.md +++ b/docs/source/backends-coreml.md @@ -104,7 +104,7 @@ import torchvision.models as models from torchvision.models.mobilenetv2 import MobileNet_V2_Weights from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer from executorch.backends.apple.coreml.partition import CoreMLPartitioner -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from executorch.exir import to_edge_transform_and_lower from executorch.backends.apple.coreml.compiler import CoreMLBackend diff --git a/docs/source/backends-xnnpack.md b/docs/source/backends-xnnpack.md index db1c055dc9c..85919952988 100644 --- a/docs/source/backends-xnnpack.md +++ b/docs/source/backends-xnnpack.md @@ -1,6 +1,6 @@ # XNNPACK Backend -The XNNPACK delegate is the ExecuTorch solution for CPU execution on mobile CPUs. 
[XNNPACK](https://github.com/google/XNNPACK/tree/master) is a library that provides optimized kernels for machine learning operators on Arm and x86 CPUs. +The XNNPACK delegate is the ExecuTorch solution for CPU execution on mobile CPUs. [XNNPACK](https://github.com/google/XNNPACK/tree/master) is a library that provides optimized kernels for machine learning operators on Arm and x86 CPUs. ## Features @@ -18,7 +18,7 @@ The XNNPACK delegate is the ExecuTorch solution for CPU execution on mobile CPUs ## Development Requirements -The XNNPACK delegate does not introduce any development system requirements beyond those required by +The XNNPACK delegate does not introduce any development system requirements beyond those required by the core ExecuTorch runtime. ---- @@ -63,7 +63,7 @@ After generating the XNNPACK-delegated .pte, the model can be tested from Python ## Quantization -The XNNPACK delegate can also be used as a backend to execute symmetrically quantized models. To quantize a PyTorch model for the XNNPACK backend, use the `XNNPACKQuantizer`. `Quantizers` are backend specific, which means the `XNNPACKQuantizer` is configured to quantize models to leverage the quantized operators offered by the XNNPACK Library. +The XNNPACK delegate can also be used as a backend to execute symmetrically quantized models. To quantize a PyTorch model for the XNNPACK backend, use the `XNNPACKQuantizer`. `Quantizers` are backend specific, which means the `XNNPACKQuantizer` is configured to quantize models to leverage the quantized operators offered by the XNNPACK Library. 
### Supported Quantization Schemes The XNNPACK delegate supports the following quantization schemes: @@ -94,8 +94,8 @@ from torchvision.models.mobilenetv2 import MobileNet_V2_Weights from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import XNNPACKQuantizer from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner from executorch.exir import to_edge_transform_and_lower -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer.xnnpack_quantizer import get_symmetric_quantization_config +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import get_symmetric_quantization_config model = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval() sample_inputs = (torch.randn(1, 3, 224, 224), ) diff --git a/docs/source/llm/getting-started.md b/docs/source/llm/getting-started.md index 152162841e4..7d54f4d2dde 100644 --- a/docs/source/llm/getting-started.md +++ b/docs/source/llm/getting-started.md @@ -619,7 +619,7 @@ from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, ) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e ``` ```python diff --git a/docs/source/tutorial-xnnpack-delegate-lowering.md b/docs/source/tutorial-xnnpack-delegate-lowering.md index add60a12deb..12793533766 100644 --- a/docs/source/tutorial-xnnpack-delegate-lowering.md +++ b/docs/source/tutorial-xnnpack-delegate-lowering.md @@ -85,7 +85,7 @@ sample_inputs = (torch.randn(1, 3, 224, 224), ) mobilenet_v2 = export_for_training(mobilenet_v2, sample_inputs).module() # 2-stage export for quantization path -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import 
convert_pt2e, prepare_pt2e from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, diff --git a/docs/source/tutorials_source/export-to-executorch-tutorial.py b/docs/source/tutorials_source/export-to-executorch-tutorial.py index de42cb51bce..2ca6a207d17 100644 --- a/docs/source/tutorials_source/export-to-executorch-tutorial.py +++ b/docs/source/tutorials_source/export-to-executorch-tutorial.py @@ -200,7 +200,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: get_symmetric_quantization_config, XNNPACKQuantizer, ) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer) # type: ignore[arg-type] diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 73fa4b24d4e..25f2a26ccf7 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -56,10 +56,10 @@ from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.extension.export_util.utils import save_pte_program from tabulate import tabulate +from torch.utils.data import DataLoader # Quantize model if required using the standard export quantizaion flow. 
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.utils.data import DataLoader +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from ..models import MODEL_NAME_TO_MODEL from ..models.model_factory import EagerModelFactory diff --git a/examples/arm/ethos_u_minimal_example.ipynb b/examples/arm/ethos_u_minimal_example.ipynb index 77be8c22447..146a586d0ab 100644 --- a/examples/arm/ethos_u_minimal_example.ipynb +++ b/examples/arm/ethos_u_minimal_example.ipynb @@ -84,7 +84,7 @@ " EthosUQuantizer,\n", " get_symmetric_quantization_config,\n", ")\n", - "from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e\n", + "from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e\n", "\n", "target = \"ethos-u55-128\"\n", "\n", diff --git a/examples/mediatek/aot_utils/oss_utils/utils.py b/examples/mediatek/aot_utils/oss_utils/utils.py index d286a380d5c..bf7b25f07c2 100755 --- a/examples/mediatek/aot_utils/oss_utils/utils.py +++ b/examples/mediatek/aot_utils/oss_utils/utils.py @@ -15,7 +15,7 @@ Precision, ) from executorch.exir.backend.backend_details import CompileSpec -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e def build_executorch_binary( diff --git a/examples/mediatek/model_export_scripts/llama.py b/examples/mediatek/model_export_scripts/llama.py index 34e935bb03b..6a098e2a9b1 100644 --- a/examples/mediatek/model_export_scripts/llama.py +++ b/examples/mediatek/model_export_scripts/llama.py @@ -43,7 +43,7 @@ Precision, ) from executorch.exir.backend.backend_details import CompileSpec -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from tqdm import tqdm warnings.filterwarnings("ignore") diff --git a/examples/models/moshi/mimi/test_mimi.py 
b/examples/models/moshi/mimi/test_mimi.py index 7e2cfb14c49..be3c075913d 100644 --- a/examples/models/moshi/mimi/test_mimi.py +++ b/examples/models/moshi/mimi/test_mimi.py @@ -19,9 +19,9 @@ from huggingface_hub import hf_hub_download from moshi.models import loaders -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export import export, ExportedProgram from torch.utils._pytree import tree_flatten +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e proxies = { "http": "http://fwdproxy:8080", diff --git a/examples/models/phi-3-mini/export_phi-3-mini.py b/examples/models/phi-3-mini/export_phi-3-mini.py index 11c2f3834eb..246b3ccd6c6 100644 --- a/examples/models/phi-3-mini/export_phi-3-mini.py +++ b/examples/models/phi-3-mini/export_phi-3-mini.py @@ -20,8 +20,8 @@ ) from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config from executorch.exir import to_edge -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export import export_for_training +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from transformers import Phi3ForCausalLM diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index 6f7bdac8e15..08a4bde779b 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -81,8 +81,8 @@ from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer -from torch.ao.quantization.observer import MinMaxObserver -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.observer import MinMaxObserver +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e sys.setrecursionlimit(4096) FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" diff 
--git a/examples/qualcomm/oss_scripts/moshi/mimi.py b/examples/qualcomm/oss_scripts/moshi/mimi.py index 6b59a71ae64..d67d378f0ce 100644 --- a/examples/qualcomm/oss_scripts/moshi/mimi.py +++ b/examples/qualcomm/oss_scripts/moshi/mimi.py @@ -37,7 +37,7 @@ from huggingface_hub import hf_hub_download from moshi.models import loaders -from torch.ao.quantization.observer import MinMaxObserver +from torchao.quantization.pt2e.observer import MinMaxObserver def seed_all(seed): diff --git a/examples/qualcomm/scripts/export_example.py b/examples/qualcomm/scripts/export_example.py index acf8a9ab468..515fdda8b41 100644 --- a/examples/qualcomm/scripts/export_example.py +++ b/examples/qualcomm/scripts/export_example.py @@ -16,7 +16,7 @@ from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.extension.export_util.utils import save_pte_program -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e def main() -> None: diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index d8dab88e998..4ed51a96340 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -19,6 +19,7 @@ import numpy as np import torch +import torchao from executorch.backends.qualcomm.quantizer.quantizer import ( ModuleQConfig, QnnQuantizer, @@ -33,8 +34,8 @@ ) from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass -from torch.ao.quantization.observer import MovingAverageMinMaxObserver -from torch.ao.quantization.quantize_pt2e import ( +from torchao.quantization.pt2e.observer import MovingAverageMinMaxObserver +from torchao.quantization.pt2e.quantize_pt2e import ( convert_pt2e, prepare_pt2e, prepare_qat_pt2e, @@ -231,7 +232,7 @@ def ptq_calibrate(captured_model, quantizer, dataset): def qat_train(ori_model, captured_model, quantizer, dataset): data, targets = dataset 
- annotated_model = torch.ao.quantization.move_exported_model_to_train( + annotated_model = torchao.quantization.pt2e.move_exported_model_to_train( prepare_qat_pt2e(captured_model, quantizer) ) optimizer = torch.optim.SGD(annotated_model.parameters(), lr=0.00001) @@ -240,7 +241,9 @@ def qat_train(ori_model, captured_model, quantizer, dataset): print(f"Epoch {i}") if i > 3: # Freeze quantizer parameters - annotated_model.apply(torch.ao.quantization.disable_observer) + annotated_model.apply( + torchao.quantization.pt2e.fake_quantize.disable_observer + ) if i > 2: # Freeze batch norm mean and variance estimates annotated_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats) @@ -251,8 +254,8 @@ def qat_train(ori_model, captured_model, quantizer, dataset): loss.backward() optimizer.step() - return torch.ao.quantization.quantize_pt2e.convert_pt2e( - torch.ao.quantization.move_exported_model_to_eval(annotated_model) + return torchao.quantization.pt2e.quantize_pt2e.convert_pt2e( + torchao.quantization.pt2e.move_exported_model_to_eval(annotated_model) ) diff --git a/examples/xnnpack/quantization/example.py b/examples/xnnpack/quantization/example.py index 90a6b94d02b..93831ab8252 100644 --- a/examples/xnnpack/quantization/example.py +++ b/examples/xnnpack/quantization/example.py @@ -29,7 +29,7 @@ _convert_to_reference_decomposed_fx, prepare_fx, ) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from ...models import MODEL_NAME_TO_MODEL from ...models.model_factory import EagerModelFactory diff --git a/examples/xnnpack/quantization/utils.py b/examples/xnnpack/quantization/utils.py index 9e49f15a99d..d7648daf5da 100644 --- a/examples/xnnpack/quantization/utils.py +++ b/examples/xnnpack/quantization/utils.py @@ -11,7 +11,7 @@ XNNPACKQuantizer, ) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import 
convert_pt2e, prepare_pt2e from .. import QuantType diff --git a/exir/backend/test/demos/test_xnnpack_qnnpack.py b/exir/backend/test/demos/test_xnnpack_qnnpack.py index 5cbd7f7f659..6038f936ca0 100644 --- a/exir/backend/test/demos/test_xnnpack_qnnpack.py +++ b/exir/backend/test/demos/test_xnnpack_qnnpack.py @@ -28,13 +28,13 @@ _load_for_executorch_from_buffer, ) from executorch.extension.pytree import tree_flatten -from torch.ao.quantization.backend_config.executorch import ( - get_executorch_backend_config, -) -from torch.ao.quantization.observer import ( +from torch.ao.quantization import ( default_dynamic_quant_observer, default_per_channel_weight_observer, ) +from torch.ao.quantization.backend_config.executorch import ( + get_executorch_backend_config, +) from torch.ao.quantization.qconfig_mapping import QConfig, QConfigMapping from torch.ao.quantization.quantize_fx import ( _convert_to_reference_decomposed_fx, diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py index b87ae2dfb58..4f7bccbbf92 100644 --- a/exir/tests/test_memory_planning.py +++ b/exir/tests/test_memory_planning.py @@ -41,10 +41,6 @@ from torch.ao.quantization.backend_config.executorch import ( get_executorch_backend_config, ) -from torch.ao.quantization.observer import ( - default_dynamic_quant_observer, - default_per_channel_weight_observer, -) from torch.ao.quantization.qconfig_mapping import QConfig, QConfigMapping from torch.ao.quantization.quantize_fx import ( _convert_to_reference_decomposed_fx, @@ -55,6 +51,10 @@ from torch.export.exported_program import ExportGraphSignature from torch.fx import Graph, GraphModule, Node from torch.nn import functional as F +from torchao.quantization.pt2e.observer import ( + default_dynamic_quant_observer, + default_per_channel_weight_observer, +) torch.ops.load_library("//executorch/kernels/portable:custom_ops_generated_lib") diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 6618c729987..0e32b6c6870 
100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -71,9 +71,6 @@ from functorch.experimental import control_flow from torch import nn - -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer import QuantizationSpec from torch.export import export from torch.export.graph_signature import InputKind, InputSpec, TensorArgument from torch.fx import GraphModule, subgraph_rewriter @@ -82,6 +79,9 @@ from torch.testing import FileCheck from torch.utils import _pytree as pytree +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import QuantizationSpec + # pyre-ignore def collect_ops(gm: torch.fx.GraphModule): @@ -1168,7 +1168,7 @@ def forward(self, query, key, value): ).module() # 8w16a quantization - from torch.ao.quantization.observer import ( + from torchao.quantization.pt2e.observer import ( MinMaxObserver, PerChannelMinMaxObserver, ) diff --git a/exir/tests/test_quantization.py b/exir/tests/test_quantization.py index 0a0a85077bb..2c66787c7c4 100644 --- a/exir/tests/test_quantization.py +++ b/exir/tests/test_quantization.py @@ -19,18 +19,17 @@ from executorch.exir.passes.quant_fusion_pass import QuantFusionPass from executorch.exir.passes.spec_prop_pass import SpecPropPass from torch.ao.ns.fx.utils import compute_sqnr -from torch.ao.quantization import QConfigMapping # @manual from torch.ao.quantization.backend_config import get_executorch_backend_config from torch.ao.quantization.qconfig import default_per_channel_symmetric_qnnpack_qconfig -from torch.ao.quantization.quantize_fx import prepare_fx -from torch.ao.quantization.quantize_pt2e import ( +from torch.ao.quantization.qconfig_mapping import QConfigMapping +from torch.ao.quantization.quantize_fx import ( # @manual _convert_to_reference_decomposed_fx, - convert_pt2e, - prepare_pt2e, + prepare_fx, ) from torch.export import export from torch.testing import FileCheck 
from torch.testing._internal.common_quantized import override_quantized_engine +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e # load executorch out variant ops torch.ops.load_library("//executorch/kernels/quantized:custom_ops_generated_lib") diff --git a/exir/tests/test_quantize_io_pass.py b/exir/tests/test_quantize_io_pass.py index ddc0294ba68..f670594616a 100644 --- a/exir/tests/test_quantize_io_pass.py +++ b/exir/tests/test_quantize_io_pass.py @@ -20,8 +20,8 @@ QuantizeOutputs, ) from executorch.exir.tensor import get_scalar_type -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.testing import FileCheck +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e op_str = { "q": "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default", diff --git a/export/export.py b/export/export.py index 7ea4de20a9a..8e9417aa965 100644 --- a/export/export.py +++ b/export/export.py @@ -13,10 +13,10 @@ from executorch.runtime import Runtime, Verification from tabulate import tabulate from torch import nn -from torch.ao.quantization import allow_exported_model_train_eval -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export import ExportedProgram from torchao.quantization import quantize_ +from torchao.quantization.pt2e import allow_exported_model_train_eval +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from torchao.utils import unwrap_tensor_subclass from .recipe import ExportRecipe diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 323311caeea..3905051ed57 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -35,11 +35,11 @@ from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes from pytorch_tokenizers import get_tokenizer -from torch.ao.quantization.quantize_pt2e import convert_pt2e, 
prepare_pt2e -from torch.ao.quantization.quantizer import Quantizer -from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer from torch.export import export_for_training, ExportedProgram from torch.nn.attention import SDPBackend +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import Quantizer +from torchao.quantization.pt2e.quantizer.composable_quantizer import ComposableQuantizer from torchao.utils import unwrap_tensor_subclass FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index d7b8b3a92b1..985ff10a396 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -16,8 +16,8 @@ XNNPACKQuantizer, ) -from torch.ao.quantization.quantizer import Quantizer -from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer +from torchao.quantization.pt2e.quantizer import Quantizer +from torchao.quantization.pt2e.quantizer.embedding_quantizer import EmbeddingQuantizer FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) @@ -154,7 +154,7 @@ def get_qnn_quantizer( QnnQuantizer, QuantDtype, ) - from torch.ao.quantization.observer import MinMaxObserver + from torchao.quantization.pt2e.observer import MinMaxObserver except ImportError: raise ImportError(