From 5f9da898ddb15c3c2d2f4a76e9cc676ccd7517fb Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Sat, 17 May 2025 15:39:09 -0700 Subject: [PATCH] Migrate ExecuTorch's use of pt2e from torch.ao to torchao (#10294) Summary: Most code related to PT2E quantization is migrating from torch.ao.quantization to torchao.quantization.pt2e. torchao.quantization.pt2e contains an exact copy of PT2E code in torch.ao.quantization. The torchao pin in ExecuTorch has already been bumped pick up these changes. Pull Request resolved: https://github.com/pytorch/executorch/pull/10294 Reviewed By: SS-JIA Differential Revision: D74694311 Pulled By: metascroy --- .lintrunner.toml | 28 ++++++++ .mypy.ini | 3 + .../coreml/test/test_coreml_quantizer.py | 4 +- backends/arm/quantizer/arm_quantizer.py | 23 ++++--- backends/arm/quantizer/arm_quantizer_utils.py | 4 +- .../arm/quantizer/quantization_annotator.py | 15 +++-- backends/arm/quantizer/quantization_config.py | 4 +- backends/arm/test/ops/test_add.py | 4 +- backends/arm/test/ops/test_sigmoid_16bit.py | 4 +- backends/arm/test/ops/test_sigmoid_32bit.py | 4 +- backends/cadence/aot/compiler.py | 2 +- backends/cadence/aot/quantizer/TARGETS | 1 + backends/cadence/aot/quantizer/patterns.py | 2 +- backends/cadence/aot/quantizer/quantizer.py | 6 +- backends/cadence/aot/quantizer/utils.py | 2 +- .../cortex_m/test/test_replace_quant_nodes.py | 10 +-- .../permute_memory_formats_pass.py | 2 +- backends/example/example_operators/utils.py | 2 +- backends/example/example_partitioner.py | 2 +- backends/example/example_quantizer.py | 9 ++- backends/example/test_example_delegate.py | 4 +- backends/mediatek/quantizer/annotator.py | 18 ++--- backends/mediatek/quantizer/qconfig.py | 6 +- backends/mediatek/quantizer/quantizer.py | 2 +- backends/nxp/quantizer/neutron_quantizer.py | 6 +- backends/nxp/quantizer/patterns.py | 2 +- backends/nxp/quantizer/utils.py | 2 +- backends/nxp/tests/executorch_pipeline.py | 2 +- 
backends/nxp/tests/test_quantizer.py | 2 +- backends/openvino/quantizer/quantizer.py | 4 +- backends/openvino/scripts/openvino_build.sh | 3 + .../qualcomm/_passes/annotate_quant_attrs.py | 4 +- backends/qualcomm/_passes/qnn_pass_manager.py | 4 +- backends/qualcomm/builders/node_visitor.py | 6 +- backends/qualcomm/partition/utils.py | 4 +- backends/qualcomm/quantizer/annotators.py | 56 ++++++++-------- .../qualcomm/quantizer/custom_annotation.py | 6 +- .../observers/per_block_param_observer.py | 4 +- .../observers/per_channel_param_observer.py | 2 +- backends/qualcomm/quantizer/qconfig.py | 11 ++-- backends/qualcomm/quantizer/quantizer.py | 5 +- backends/qualcomm/tests/utils.py | 14 ++-- backends/qualcomm/utils/utils.py | 2 +- .../duplicate_dynamic_quant_chain.py | 9 +-- .../test_duplicate_dynamic_quant_chain.py | 2 +- backends/vulkan/quantizer/vulkan_quantizer.py | 4 +- backends/vulkan/test/test_vulkan_delegate.py | 6 +- backends/vulkan/test/test_vulkan_passes.py | 4 +- .../xnnpack/quantizer/xnnpack_quantizer.py | 16 ++--- .../quantizer/xnnpack_quantizer_utils.py | 36 +++++----- .../test/ops/test_check_quant_params.py | 2 +- .../test/quantizer/test_pt2e_quantization.py | 65 ++++++++++++------- .../test/quantizer/test_representation.py | 4 +- .../test/quantizer/test_xnnpack_quantizer.py | 11 ++-- backends/xnnpack/test/test_xnnpack_utils.py | 18 ++--- backends/xnnpack/test/tester/tester.py | 4 +- docs/source/backends-coreml.md | 2 +- docs/source/backends-xnnpack.md | 10 +-- docs/source/llm/getting-started.md | 2 +- .../tutorial-xnnpack-delegate-lowering.md | 2 +- .../export-to-executorch-tutorial.py | 2 +- examples/arm/aot_arm_compiler.py | 4 +- examples/arm/ethos_u_minimal_example.ipynb | 2 +- .../mediatek/aot_utils/oss_utils/utils.py | 2 +- .../mediatek/model_export_scripts/llama.py | 2 +- examples/models/moshi/mimi/test_mimi.py | 2 +- .../models/phi-3-mini/export_phi-3-mini.py | 2 +- examples/qualcomm/oss_scripts/llama/llama.py | 4 +- 
examples/qualcomm/oss_scripts/moshi/mimi.py | 2 +- examples/qualcomm/scripts/export_example.py | 2 +- examples/qualcomm/utils.py | 15 +++-- examples/xnnpack/quantization/example.py | 2 +- examples/xnnpack/quantization/utils.py | 2 +- .../test/demos/test_xnnpack_qnnpack.py | 8 +-- exir/tests/test_memory_planning.py | 8 +-- exir/tests/test_passes.py | 8 +-- exir/tests/test_quantization.py | 9 ++- exir/tests/test_quantize_io_pass.py | 2 +- export/export.py | 4 +- extension/llm/export/builder.py | 6 +- extension/llm/export/quantizer_lib.py | 6 +- 81 files changed, 324 insertions(+), 267 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index 2835af1bf92..a26e690c3c0 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -378,3 +378,31 @@ command = [ '--', '@{{PATHSFILE}}', ] + +[[linter]] +code = "TORCH_AO_IMPORT" +include_patterns = ["**/*.py"] +exclude_patterns = [ + "third-party/**", +] + +command = [ + "python3", + "-m", + "lintrunner_adapters", + "run", + "grep_linter", + "--pattern=\\bfrom torch\\.ao\\.quantization\\.(?:quantizer|observer|quantize_pt2e|pt2e)(?:\\.[A-Za-z0-9_]+)*\\b", + "--linter-name=TorchAOImport", + "--error-name=Prohibited torch.ao.quantization import", + """--error-description=\ + Imports from torch.ao.quantization are not allowed. \ + Please import from torchao.quantization.pt2e instead.\n \ + * torchao.quantization.pt2e (includes all the utils, including observers, fake quants etc.) \n \ + * torchao.quantization.pt2e.quantizer (quantizer related objects and utils) \n \ + * torchao.quantization.pt2e.quantize_pt2e (prepare_pt2e, prepare_qat_pt2e, convert_pt2e) \n\n \ + If you need something from torch.ao.quantization, you can add your file to an exclude_patterns for TORCH_AO_IMPORT in .lintrunner.toml. 
\ + """, + "--", + "@{{PATHSFILE}}", +] diff --git a/.mypy.ini b/.mypy.ini index 5ee07ddb2bf..cd14cbac7ea 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -97,3 +97,6 @@ ignore_missing_imports = True [mypy-zstd] ignore_missing_imports = True + +[mypy-torchao.*] +follow_untyped_imports = True diff --git a/backends/apple/coreml/test/test_coreml_quantizer.py b/backends/apple/coreml/test/test_coreml_quantizer.py index db75631dbc8..d5754328796 100644 --- a/backends/apple/coreml/test/test_coreml_quantizer.py +++ b/backends/apple/coreml/test/test_coreml_quantizer.py @@ -15,12 +15,12 @@ ) from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer -from torch.ao.quantization.quantize_pt2e import ( +from torch.export import export_for_training +from torchao.quantization.pt2e.quantize_pt2e import ( convert_pt2e, prepare_pt2e, prepare_qat_pt2e, ) -from torch.export import export_for_training class TestCoreMLQuantizer: diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 8b052116ed8..0d4c7d91ae8 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -30,25 +30,24 @@ is_vgf, ) # usort: skip from executorch.exir.backend.compile_spec_schema import CompileSpec -from torch.ao.quantization.fake_quantize import ( +from torch.fx import GraphModule, Node +from torchao.quantization.pt2e import ( FakeQuantize, FusedMovingAvgObsFakeQuantize, -) -from torch.ao.quantization.observer import ( HistogramObserver, MinMaxObserver, MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver, + ObserverOrFakeQuantizeConstructor, PerChannelMinMaxObserver, PlaceholderObserver, ) -from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor -from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer -from torch.ao.quantization.quantizer.utils import ( - _annotate_input_qspec_map, - _annotate_output_qspec, +from torchao.quantization.pt2e.quantizer import ( + 
annotate_input_qspec_map, + annotate_output_qspec, + QuantizationSpec, + Quantizer, ) -from torch.fx import GraphModule, Node __all__ = [ "TOSAQuantizer", @@ -97,7 +96,7 @@ def get_symmetric_quantization_config( weight_qscheme = ( torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric ) - weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( + weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = ( MinMaxObserver ) if is_qat: @@ -337,14 +336,14 @@ def _annotate_io( if is_annotated(node): continue if node.op == "placeholder" and len(node.users) > 0: - _annotate_output_qspec( + annotate_output_qspec( node, quantization_config.get_output_act_qspec(), ) mark_node_as_annotated(node) if node.op == "output": parent = node.all_input_nodes[0] - _annotate_input_qspec_map( + annotate_input_qspec_map( node, parent, quantization_config.get_input_act_qspec() ) mark_node_as_annotated(node) diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index 0ce11b620a6..d6eb72f1148 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -15,10 +15,10 @@ import torch from torch._subclasses import FakeTensor - -from torch.ao.quantization.quantizer import QuantizationAnnotation from torch.fx import GraphModule, Node +from torchao.quantization.pt2e.quantizer import QuantizationAnnotation + def is_annotated(node: Node) -> bool: """Given a node return whether the node is annotated.""" diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py index 55c2ca21e1b..8d04227e1a7 100644 --- a/backends/arm/quantizer/quantization_annotator.py +++ b/backends/arm/quantizer/quantization_annotator.py @@ -12,12 +12,13 @@ import torch.fx from executorch.backends.arm.quantizer import QuantizationConfig from executorch.backends.arm.tosa_utils import get_node_debug_info -from 
torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec -from torch.ao.quantization.quantizer.utils import ( - _annotate_input_qspec_map, - _annotate_output_qspec, -) from torch.fx import Node +from torchao.quantization.pt2e.quantizer import ( + annotate_input_qspec_map, + annotate_output_qspec, + QuantizationSpecBase, + SharedQuantizationSpec, +) from .arm_quantizer_utils import ( is_annotated, @@ -118,7 +119,7 @@ def _annotate_input(node: Node, quant_property: _QuantProperty): strict=True, ): assert isinstance(n_arg, Node) - _annotate_input_qspec_map(node, n_arg, qspec) + annotate_input_qspec_map(node, n_arg, qspec) if quant_property.mark_annotated: mark_node_as_annotated(n_arg) # type: ignore[attr-defined] @@ -129,7 +130,7 @@ def _annotate_output(node: Node, quant_property: _QuantProperty): assert not quant_property.optional assert quant_property.index == 0, "Only one output annotation supported currently" - _annotate_output_qspec(node, quant_property.qspec) + annotate_output_qspec(node, quant_property.qspec) def _match_pattern( diff --git a/backends/arm/quantizer/quantization_config.py b/backends/arm/quantizer/quantization_config.py index 65435ac7c63..54698d058c4 100644 --- a/backends/arm/quantizer/quantization_config.py +++ b/backends/arm/quantizer/quantization_config.py @@ -9,9 +9,9 @@ from dataclasses import dataclass import torch -from torch.ao.quantization import ObserverOrFakeQuantize +from torchao.quantization.pt2e import ObserverOrFakeQuantize -from torch.ao.quantization.quantizer import ( +from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, FixedQParamsQuantizationSpec, QuantizationSpec, diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index 76d4950be6d..e3efec02253 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -19,8 +19,8 @@ ) from executorch.backends.arm.tosa_specification import TosaSpecification from 
executorch.backends.xnnpack.test.tester import Quantize -from torch.ao.quantization.observer import HistogramObserver -from torch.ao.quantization.quantizer import QuantizationSpec +from torchao.quantization.pt2e.observer import HistogramObserver +from torchao.quantization.pt2e.quantizer import QuantizationSpec aten_op = "torch.ops.aten.add.Tensor" exir_op = "executorch_exir_dialects_edge__ops_aten_add_Tensor" diff --git a/backends/arm/test/ops/test_sigmoid_16bit.py b/backends/arm/test/ops/test_sigmoid_16bit.py index ff0fe9cc4c1..5d32bdaf279 100644 --- a/backends/arm/test/ops/test_sigmoid_16bit.py +++ b/backends/arm/test/ops/test_sigmoid_16bit.py @@ -18,8 +18,8 @@ ) from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.xnnpack.test.tester import Quantize -from torch.ao.quantization.observer import HistogramObserver -from torch.ao.quantization.quantizer import QuantizationSpec +from torchao.quantization.pt2e.observer import HistogramObserver +from torchao.quantization.pt2e.quantizer import QuantizationSpec def _get_16_bit_quant_config(): diff --git a/backends/arm/test/ops/test_sigmoid_32bit.py b/backends/arm/test/ops/test_sigmoid_32bit.py index 4edfdd6c23e..ad44fb52f6d 100644 --- a/backends/arm/test/ops/test_sigmoid_32bit.py +++ b/backends/arm/test/ops/test_sigmoid_32bit.py @@ -14,8 +14,8 @@ ) from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.xnnpack.test.tester import Quantize -from torch.ao.quantization.observer import HistogramObserver -from torch.ao.quantization.quantizer import QuantizationSpec +from torchao.quantization.pt2e.observer import HistogramObserver +from torchao.quantization.pt2e.quantizer import QuantizationSpec def _get_16_bit_quant_config(): diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index 594c4189b3a..438f07ba15f 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -37,9 +37,9 @@ from 
executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass from executorch.exir.program._program import to_edge_with_preserved_ops from torch._inductor.decomposition import remove_decompositions -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export.exported_program import ExportedProgram +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from .passes import get_cadence_passes diff --git a/backends/cadence/aot/quantizer/TARGETS b/backends/cadence/aot/quantizer/TARGETS index 75eab631dd4..f2d6e5572b7 100644 --- a/backends/cadence/aot/quantizer/TARGETS +++ b/backends/cadence/aot/quantizer/TARGETS @@ -21,6 +21,7 @@ python_library( deps = [ ":utils", "//caffe2:torch", + "//pytorch/ao:torchao", ], ) diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py index 66f6772d942..cd6a7287793 100644 --- a/backends/cadence/aot/quantizer/patterns.py +++ b/backends/cadence/aot/quantizer/patterns.py @@ -15,7 +15,7 @@ from torch import fx from torch._ops import OpOverload -from torch.ao.quantization.quantizer import ( +from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, SharedQuantizationSpec, ) diff --git a/backends/cadence/aot/quantizer/quantizer.py b/backends/cadence/aot/quantizer/quantizer.py index 761b2bf8d31..3fbe1bcc0fd 100644 --- a/backends/cadence/aot/quantizer/quantizer.py +++ b/backends/cadence/aot/quantizer/quantizer.py @@ -38,9 +38,9 @@ from torch import fx -from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver -from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer -from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer +from torchao.quantization.pt2e.observer import HistogramObserver, MinMaxObserver +from torchao.quantization.pt2e.quantizer import DerivedQuantizationSpec, Quantizer +from torchao.quantization.pt2e.quantizer.composable_quantizer 
import ComposableQuantizer act_qspec_asym8s = QuantizationSpec( diff --git a/backends/cadence/aot/quantizer/utils.py b/backends/cadence/aot/quantizer/utils.py index 0f9c9399780..fad5ca41e22 100644 --- a/backends/cadence/aot/quantizer/utils.py +++ b/backends/cadence/aot/quantizer/utils.py @@ -14,13 +14,13 @@ import torch from torch import fx from torch._ops import OpOverload -from torch.ao.quantization import ObserverOrFakeQuantize from torch.fx import GraphModule from torch.fx.passes.utils.source_matcher_utils import ( check_subgraphs_connected, SourcePartition, ) +from torchao.quantization.pt2e import ObserverOrFakeQuantize def quantize_tensor_multiplier( diff --git a/backends/cortex_m/test/test_replace_quant_nodes.py b/backends/cortex_m/test/test_replace_quant_nodes.py index f993b42c920..54f5142add8 100644 --- a/backends/cortex_m/test/test_replace_quant_nodes.py +++ b/backends/cortex_m/test/test_replace_quant_nodes.py @@ -16,15 +16,15 @@ ReplaceQuantNodesPass, ) from executorch.exir.dialects._ops import ops as exir_ops -from torch.ao.quantization.observer import HistogramObserver -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer.quantizer import ( +from torch.export import export, export_for_training +from torch.fx import GraphModule +from torchao.quantization.pt2e.observer import HistogramObserver +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import ( QuantizationAnnotation, QuantizationSpec, Quantizer, ) -from torch.export import export, export_for_training -from torch.fx import GraphModule @dataclass(eq=True, frozen=True) diff --git a/backends/example/example_backend_delegate_passes/permute_memory_formats_pass.py b/backends/example/example_backend_delegate_passes/permute_memory_formats_pass.py index 8e857f32376..4f78c59927a 100644 --- a/backends/example/example_backend_delegate_passes/permute_memory_formats_pass.py +++ 
b/backends/example/example_backend_delegate_passes/permute_memory_formats_pass.py @@ -11,7 +11,7 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dim_order_utils import get_dim_order from executorch.exir.pass_base import ExportPass, PassResult -from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions +from torchao.quantization.pt2e import find_sequential_partitions class PermuteMemoryFormatsPass(ExportPass): diff --git a/backends/example/example_operators/utils.py b/backends/example/example_operators/utils.py index 7dca2a3be6a..d9b3a436840 100644 --- a/backends/example/example_operators/utils.py +++ b/backends/example/example_operators/utils.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation +from torchao.quantization.pt2e.quantizer.quantizer import QuantizationAnnotation def _nodes_are_annotated(node_list): diff --git a/backends/example/example_partitioner.py b/backends/example/example_partitioner.py index 5e9102e999b..7bfe783a4f3 100644 --- a/backends/example/example_partitioner.py +++ b/backends/example/example_partitioner.py @@ -19,9 +19,9 @@ ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.graph_module import get_control_flow_submodules -from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions from torch.export import ExportedProgram from torch.fx.passes.operator_support import OperatorSupportBase +from torchao.quantization.pt2e import find_sequential_partitions @final diff --git a/backends/example/example_quantizer.py b/backends/example/example_quantizer.py index 74a0057ba4a..c5a3e179695 100644 --- a/backends/example/example_quantizer.py +++ b/backends/example/example_quantizer.py @@ -11,9 +11,12 @@ from executorch.backends.example.example_operators.ops import module_to_annotator 
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import OperatorConfig from torch import fx -from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver -from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions -from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer +from torchao.quantization.pt2e import ( + find_sequential_partitions, + HistogramObserver, + MinMaxObserver, +) +from torchao.quantization.pt2e.quantizer import QuantizationSpec, Quantizer def get_uint8_tensor_spec(observer_or_fake_quant_ctr): diff --git a/backends/example/test_example_delegate.py b/backends/example/test_example_delegate.py index a382273af07..bc6ad4d7e4c 100644 --- a/backends/example/test_example_delegate.py +++ b/backends/example/test_example_delegate.py @@ -17,10 +17,10 @@ DuplicateDequantNodePass, ) from executorch.exir.delegate import executorch_call_delegate - -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export import export +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e + from torchvision.models.quantization import mobilenet_v2 diff --git a/backends/mediatek/quantizer/annotator.py b/backends/mediatek/quantizer/annotator.py index d250b774af8..efdde09be88 100644 --- a/backends/mediatek/quantizer/annotator.py +++ b/backends/mediatek/quantizer/annotator.py @@ -10,18 +10,18 @@ from torch._ops import OpOverload from torch._subclasses import FakeTensor -from torch.ao.quantization.quantizer import QuantizationAnnotation -from torch.ao.quantization.quantizer.utils import ( - _annotate_input_qspec_map, - _annotate_output_qspec, -) - from torch.export import export_for_training from torch.fx import Graph, Node from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( SubgraphMatcherWithNameNodeMap, ) +from torchao.quantization.pt2e.quantizer import ( + annotate_input_qspec_map, + annotate_output_qspec as _annotate_output_qspec, + 
QuantizationAnnotation, +) + from .qconfig import QuantizationConfig @@ -108,7 +108,7 @@ def _annotate_fused_activation_pattern( torch.ops.aten.linear.default, ]: weight_node = producer_node.args[1] - _annotate_input_qspec_map( + annotate_input_qspec_map( producer_node, weight_node, quant_config.weight, @@ -201,7 +201,7 @@ def annotate_affine_ops(node: Node, quant_config: QuantizationConfig) -> None: return weight_node = node.args[1] - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight_node, quant_config.weight, @@ -260,5 +260,5 @@ def annotate_embedding_op(node: Node, quant_config: QuantizationConfig) -> None: return wgt_node = node.args[0] - _annotate_input_qspec_map(node, wgt_node, quant_config.activation) + annotate_input_qspec_map(node, wgt_node, quant_config.activation) _mark_as_annotated([node]) diff --git a/backends/mediatek/quantizer/qconfig.py b/backends/mediatek/quantizer/qconfig.py index e16f5e936cb..d9f105cd0d0 100644 --- a/backends/mediatek/quantizer/qconfig.py +++ b/backends/mediatek/quantizer/qconfig.py @@ -10,9 +10,9 @@ import torch -from torch.ao.quantization.fake_quantize import FakeQuantize -from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver -from torch.ao.quantization.quantizer import QuantizationSpec +from torchao.quantization.pt2e.fake_quantize import FakeQuantize +from torchao.quantization.pt2e.observer import MinMaxObserver, PerChannelMinMaxObserver +from torchao.quantization.pt2e.quantizer import QuantizationSpec @unique diff --git a/backends/mediatek/quantizer/quantizer.py b/backends/mediatek/quantizer/quantizer.py index 4e78d6dff1a..f9babdec997 100644 --- a/backends/mediatek/quantizer/quantizer.py +++ b/backends/mediatek/quantizer/quantizer.py @@ -4,8 +4,8 @@ # except in compliance with the License. See the license file in the root # directory of this source tree for more details. 
-from torch.ao.quantization.quantizer import Quantizer from torch.fx import GraphModule +from torchao.quantization.pt2e.quantizer import Quantizer from .._passes.decompose_scaled_dot_product_attention import ( DecomposeScaledDotProductAttention, diff --git a/backends/nxp/quantizer/neutron_quantizer.py b/backends/nxp/quantizer/neutron_quantizer.py index eff7f513cb9..8a64170b632 100644 --- a/backends/nxp/quantizer/neutron_quantizer.py +++ b/backends/nxp/quantizer/neutron_quantizer.py @@ -35,9 +35,9 @@ QuantizationSpec, ) from torch import fx -from torch.ao.quantization.observer import HistogramObserver, MinMaxObserver -from torch.ao.quantization.quantizer import DerivedQuantizationSpec, Quantizer -from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer +from torchao.quantization.pt2e.observer import HistogramObserver, MinMaxObserver +from torchao.quantization.pt2e.quantizer import DerivedQuantizationSpec, Quantizer +from torchao.quantization.pt2e.quantizer.composable_quantizer import ComposableQuantizer class NeutronAtenQuantizer(Quantizer): diff --git a/backends/nxp/quantizer/patterns.py b/backends/nxp/quantizer/patterns.py index 6797447c50c..b71f0621002 100644 --- a/backends/nxp/quantizer/patterns.py +++ b/backends/nxp/quantizer/patterns.py @@ -14,7 +14,7 @@ from executorch.backends.nxp.quantizer.utils import get_bias_qparams from torch import fx from torch._ops import OpOverload -from torch.ao.quantization.quantizer import ( +from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, FixedQParamsQuantizationSpec, SharedQuantizationSpec, diff --git a/backends/nxp/quantizer/utils.py b/backends/nxp/quantizer/utils.py index 1effcdff25a..1b941f6e632 100644 --- a/backends/nxp/quantizer/utils.py +++ b/backends/nxp/quantizer/utils.py @@ -14,11 +14,11 @@ import torch from torch import fx from torch._ops import OpOverload -from torch.ao.quantization import ObserverOrFakeQuantize from torch.fx.passes.utils.source_matcher_utils 
import ( check_subgraphs_connected, SourcePartition, ) +from torchao.quantization.pt2e import ObserverOrFakeQuantize def is_annotated(nodes: List[fx.Node]) -> bool: diff --git a/backends/nxp/tests/executorch_pipeline.py b/backends/nxp/tests/executorch_pipeline.py index 6c452b99baf..fe1cdeaf751 100644 --- a/backends/nxp/tests/executorch_pipeline.py +++ b/backends/nxp/tests/executorch_pipeline.py @@ -20,7 +20,7 @@ to_edge_transform_and_lower, ) from torch import nn -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]): diff --git a/backends/nxp/tests/test_quantizer.py b/backends/nxp/tests/test_quantizer.py index 868a94059b5..dd1b691a18f 100644 --- a/backends/nxp/tests/test_quantizer.py +++ b/backends/nxp/tests/test_quantizer.py @@ -8,7 +8,7 @@ import executorch.backends.nxp.tests.models as models import torch from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e def _get_target_name(node): diff --git a/backends/openvino/quantizer/quantizer.py b/backends/openvino/quantizer/quantizer.py index 5532235f573..134ed5cb4ac 100644 --- a/backends/openvino/quantizer/quantizer.py +++ b/backends/openvino/quantizer/quantizer.py @@ -17,12 +17,12 @@ import torch.fx from nncf.common.graph.graph import NNCFGraph # type: ignore[import-untyped] -from torch.ao.quantization.observer import ( +from torchao.quantization.pt2e.observer import ( HistogramObserver, PerChannelMinMaxObserver, UniformQuantizationObserverBase, ) -from torch.ao.quantization.quantizer.quantizer import ( +from torchao.quantization.pt2e.quantizer.quantizer import ( EdgeOrNode, QuantizationAnnotation, QuantizationSpec, diff --git 
a/backends/openvino/scripts/openvino_build.sh b/backends/openvino/scripts/openvino_build.sh index 83ffd7542f3..e36501c941a 100755 --- a/backends/openvino/scripts/openvino_build.sh +++ b/backends/openvino/scripts/openvino_build.sh @@ -59,6 +59,9 @@ main() { # Build the package pip install . --no-build-isolation + # Install torchao + pip install third-party/ao + else echo "Error: Argument is not valid: $build_type" exit 1 # Exit the script with an error code diff --git a/backends/qualcomm/_passes/annotate_quant_attrs.py b/backends/qualcomm/_passes/annotate_quant_attrs.py index b4f14fc28cd..5a41f04ff7b 100644 --- a/backends/qualcomm/_passes/annotate_quant_attrs.py +++ b/backends/qualcomm/_passes/annotate_quant_attrs.py @@ -112,9 +112,9 @@ def _dequant_fold_params(self, n, quant_attrs, param): offsets = self._expand(quant_attrs[QCOM_ZERO_POINTS], dim, axis) param = param.sub(offsets).mul(scales).to(torch.float32).contiguous() elif quant_attrs[QCOM_ENCODING] in [ - exir_ops.edge.pt2e_quant.dequantize_affine.default + exir_ops.edge.torchao.dequantize_affine.default ]: - param = torch.ops.pt2e_quant.dequantize_affine( + param = torch.ops.torchao.dequantize_affine( param, block_size=quant_attrs[QCOM_BLOCK_SIZE], scale=quant_attrs[QCOM_SCALE], diff --git a/backends/qualcomm/_passes/qnn_pass_manager.py b/backends/qualcomm/_passes/qnn_pass_manager.py index c98f27db120..c4b730dc5b2 100644 --- a/backends/qualcomm/_passes/qnn_pass_manager.py +++ b/backends/qualcomm/_passes/qnn_pass_manager.py @@ -131,8 +131,8 @@ def get_to_edge_transform_passes( from executorch.backends.qualcomm._passes import utils from executorch.exir.dialects._ops import ops as exir_ops - utils.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default) - utils.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default) + utils.q_ops.add(exir_ops.edge.torchao.quantize_affine.default) + utils.dq_ops.add(exir_ops.edge.torchao.dequantize_affine.default) passes_job = ( passes_job if passes_job is not None else 
get_capture_program_passes() diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index 7965a30caea..e99d6b2e620 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -242,8 +242,8 @@ def get_quant_encoding_conf( ) # TODO: refactor this when target could be correctly detected per_block_encoding = { - exir_ops.edge.pt2e_quant.quantize_affine.default, - exir_ops.edge.pt2e_quant.dequantize_affine.default, + exir_ops.edge.torchao.quantize_affine.default, + exir_ops.edge.torchao.dequantize_affine.default, } if quant_attrs[QCOM_ENCODING] in per_block_encoding: return self.make_qnn_per_block_config(node, quant_attrs) @@ -271,7 +271,7 @@ def get_quant_tensor_value( axis_order.index(x) for x in range(len(axis_order)) ) tensor = tensor.permute(origin_order) - tensor = torch.ops.pt2e_quant.quantize_affine( + tensor = torch.ops.torchao.quantize_affine( tensor, block_size=quant_attrs[QCOM_BLOCK_SIZE], scale=quant_attrs[QCOM_SCALE], diff --git a/backends/qualcomm/partition/utils.py b/backends/qualcomm/partition/utils.py index 816d1ac1d9b..05bbd1ff970 100644 --- a/backends/qualcomm/partition/utils.py +++ b/backends/qualcomm/partition/utils.py @@ -57,7 +57,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]: torch.ops.aten.upsample_bicubic2d.vec, # This request is ignored because it is in a blocklist. 
Refer to exir/program/_program.py torch.ops.aten.unbind.int, - torch.ops.pt2e_quant.quantize_affine.default, - torch.ops.pt2e_quant.dequantize_affine.default, + torch.ops.torchao.quantize_affine.default, + torch.ops.torchao.dequantize_affine.default, ] return do_not_decompose diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index 469a801feeb..545039555cf 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -12,20 +12,20 @@ from torch._ops import OpOverload from torch._subclasses import FakeTensor -from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize +from torch.fx import Node +from torchao.quantization.pt2e.fake_quantize import FixedQParamsFakeQuantize -from torch.ao.quantization.observer import FixedQParamsObserver -from torch.ao.quantization.quantizer import ( +from torchao.quantization.pt2e.observer import FixedQParamsObserver +from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, QuantizationAnnotation, QuantizationSpec, SharedQuantizationSpec, ) -from torch.ao.quantization.quantizer.utils import ( - _annotate_input_qspec_map, - _annotate_output_qspec, +from torchao.quantization.pt2e.quantizer.utils import ( + annotate_input_qspec_map, + annotate_output_qspec, ) -from torch.fx import Node from .qconfig import ( get_16a16w_qnn_ptq_config, @@ -618,19 +618,19 @@ def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> No return # TODO current only support 16a16w - _annotate_input_qspec_map( + annotate_input_qspec_map( node, act_node, quantization_config.input_activation, ) - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight_node, quantization_config.input_activation, ) nodes_to_mark_annotated = [node] - _annotate_output_qspec(node, quantization_config.output_activation) + annotate_output_qspec(node, quantization_config.output_activation) 
_mark_nodes_as_annotated(nodes_to_mark_annotated) @@ -819,25 +819,25 @@ def annotate_group_norm(node: Node, quantization_config: QuantizationConfig) -> if _is_annotated([node]): return - _annotate_input_qspec_map( + annotate_input_qspec_map( node, act_node, quantization_config.input_activation, ) - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight_node, quantization_config.weight, ) nodes_to_mark_annotated = [node, weight_node] if bias_node: - _annotate_input_qspec_map( + annotate_input_qspec_map( node, bias_node, quantization_config.bias, ) nodes_to_mark_annotated.append(bias_node) - _annotate_output_qspec(node, quantization_config.output_activation) + annotate_output_qspec(node, quantization_config.output_activation) _mark_nodes_as_annotated(nodes_to_mark_annotated) @@ -1002,12 +1002,12 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None if _is_annotated([node]): return - _annotate_input_qspec_map( + annotate_input_qspec_map( node, act_node, quantization_config.input_activation, ) - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight_node, quantization_config.weight, @@ -1018,9 +1018,9 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None bias_config = quantization_config.bias(node) else: bias_config = quantization_config.bias - _annotate_input_qspec_map(node, bias_node, bias_config) + annotate_input_qspec_map(node, bias_node, bias_config) nodes_to_mark_annotated.append(bias_node) - _annotate_output_qspec(node, quantization_config.output_activation) + annotate_output_qspec(node, quantization_config.output_activation) _mark_nodes_as_annotated(nodes_to_mark_annotated) # We use get_source_partition in pass, but it is the same source for MultiheadAttention, so we need to change its source_fn_stack. 
@@ -1038,14 +1038,14 @@ def annotate_batch_and_instance_norm( return annotated_args = [act] - _annotate_input_qspec_map( + annotate_input_qspec_map( node, act, quantization_config.input_activation, ) # QNN requires uint8 instead of int8 in 'weight' config if weight is not None: - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight, quantization_config.input_activation, @@ -1053,14 +1053,14 @@ def annotate_batch_and_instance_norm( annotated_args.append(weight) if bias is not None: - _annotate_input_qspec_map( + annotate_input_qspec_map( node, bias, quantization_config.bias, ) annotated_args.append(bias) - _annotate_output_qspec(node, quantization_config.output_activation) + annotate_output_qspec(node, quantization_config.output_activation) _mark_nodes_as_annotated([node, *annotated_args]) @@ -1070,7 +1070,7 @@ def annotate_getitem(node: Node, quantization_config: QuantizationConfig) -> Non return if _is_float_tensor(node): - _annotate_output_qspec(node, quantization_config.output_activation) + annotate_output_qspec(node, quantization_config.output_activation) _mark_nodes_as_annotated([node]) @@ -1086,32 +1086,32 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> return input_act_qspec = quantization_config.input_activation - _annotate_input_qspec_map( + annotate_input_qspec_map( node, act_node, input_act_qspec, ) if input_act_qspec.dtype == torch.int32: - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight_node, get_16a16w_qnn_ptq_config().weight, ) else: - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight_node, input_act_qspec, ) nodes_to_mark_annotated = [node, weight_node] if bias_node: - _annotate_input_qspec_map( + annotate_input_qspec_map( node, bias_node, quantization_config.bias, ) nodes_to_mark_annotated.append(bias_node) - _annotate_output_qspec(node, quantization_config.output_activation) + annotate_output_qspec(node, quantization_config.output_activation) 
_mark_nodes_as_annotated(nodes_to_mark_annotated) diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index bda91609f1c..2771645c3e0 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -17,13 +17,13 @@ QuantizationConfig, ) from executorch.exir.dialects._ops import ops as exir_ops -from torch.ao.quantization.observer import FixedQParamsObserver, MinMaxObserver -from torch.ao.quantization.quantizer import ( +from torch.fx import Node +from torchao.quantization.pt2e.observer import FixedQParamsObserver, MinMaxObserver +from torchao.quantization.pt2e.quantizer import ( QuantizationAnnotation, QuantizationSpec, SharedQuantizationSpec, ) -from torch.fx import Node def annotate_mimi_decoder(gm: torch.fx.GraphModule): diff --git a/backends/qualcomm/quantizer/observers/per_block_param_observer.py b/backends/qualcomm/quantizer/observers/per_block_param_observer.py index e60f15c6d9c..982f3fdbb65 100644 --- a/backends/qualcomm/quantizer/observers/per_block_param_observer.py +++ b/backends/qualcomm/quantizer/observers/per_block_param_observer.py @@ -7,12 +7,12 @@ from typing import Tuple import torch -from torch.ao.quantization.observer import MappingType, PerBlock -from torch.ao.quantization.pt2e._affine_quantization import ( +from torchao.quantization.pt2e._affine_quantization import ( _get_reduction_params, AffineQuantizedMinMaxObserver, choose_qparams_affine_with_min_max, ) +from torchao.quantization.pt2e.observer import MappingType, PerBlock class PerBlockParamObserver(AffineQuantizedMinMaxObserver): diff --git a/backends/qualcomm/quantizer/observers/per_channel_param_observer.py b/backends/qualcomm/quantizer/observers/per_channel_param_observer.py index 3c04e620308..cf57c94b72e 100644 --- a/backends/qualcomm/quantizer/observers/per_channel_param_observer.py +++ b/backends/qualcomm/quantizer/observers/per_channel_param_observer.py @@ -5,7 +5,7 
@@ # LICENSE file in the root directory of this source tree. import torch -from torch.ao.quantization.observer import UniformQuantizationObserverBase +from torchao.quantization.pt2e.observer import UniformQuantizationObserverBase # TODO move to torch/ao/quantization/observer.py. diff --git a/backends/qualcomm/quantizer/qconfig.py b/backends/qualcomm/quantizer/qconfig.py index 67968363eb6..e0a838cc32e 100644 --- a/backends/qualcomm/quantizer/qconfig.py +++ b/backends/qualcomm/quantizer/qconfig.py @@ -7,18 +7,21 @@ PerBlockParamObserver, ) from torch import Tensor -from torch.ao.quantization.fake_quantize import ( +from torch.fx import Node +from torchao.quantization.pt2e.fake_quantize import ( FakeQuantize, FusedMovingAvgObsFakeQuantize, ) -from torch.ao.quantization.observer import ( +from torchao.quantization.pt2e.observer import ( MinMaxObserver, MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver, PerChannelMinMaxObserver, ) -from torch.ao.quantization.quantizer import DerivedQuantizationSpec, QuantizationSpec -from torch.fx import Node +from torchao.quantization.pt2e.quantizer import ( + DerivedQuantizationSpec, + QuantizationSpec, +) @dataclass(eq=True) diff --git a/backends/qualcomm/quantizer/quantizer.py b/backends/qualcomm/quantizer/quantizer.py index 8e65607dd84..4a1bca70add 100644 --- a/backends/qualcomm/quantizer/quantizer.py +++ b/backends/qualcomm/quantizer/quantizer.py @@ -9,11 +9,12 @@ from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple import torch +import torchao from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager from torch._ops import OpOverload -from torch.ao.quantization.quantizer import Quantizer from torch.fx import GraphModule +from torchao.quantization.pt2e.quantizer import Quantizer from .annotators import OP_ANNOTATOR @@ -131,7 +132,7 @@ class ModuleQConfig: is_conv_per_channel: bool = False is_linear_per_channel: bool = False act_observer: Optional[ - 
torch.ao.quantization.observer.UniformQuantizationObserverBase + torchao.quantization.pt2e.observer.UniformQuantizationObserverBase ] = None def __post_init__(self): diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index 695c846de05..bafa753f982 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -14,6 +14,8 @@ import numpy as np import torch +import torchao + from executorch import exir from executorch.backends.qualcomm.qnn_preprocess import QnnBackend from executorch.backends.qualcomm.quantizer.quantizer import ModuleQConfig, QuantDtype @@ -43,12 +45,12 @@ from executorch.exir.pass_base import ExportPass from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from executorch.exir.program import ExecutorchProgram, ExecutorchProgramManager -from torch.ao.quantization.quantize_pt2e import ( +from torch.fx.passes.infra.pass_base import PassResult +from torchao.quantization.pt2e.quantize_pt2e import ( convert_pt2e, prepare_pt2e, prepare_qat_pt2e, ) -from torch.fx.passes.infra.pass_base import PassResult def generate_context_binary( @@ -536,8 +538,8 @@ def get_qdq_module( torch.ops.quantized_decomposed.dequantize_per_tensor.default, torch.ops.quantized_decomposed.quantize_per_channel.default, torch.ops.quantized_decomposed.dequantize_per_channel.default, - torch.ops.pt2e_quant.quantize_affine.default, - torch.ops.pt2e_quant.dequantize_affine.default, + torch.ops.torchao.quantize_affine.default, + torch.ops.torchao.dequantize_affine.default, } if not bypass_check: self.assertTrue(nodes.intersection(q_and_dq)) @@ -568,7 +570,7 @@ def get_prepared_qat_module( quantizer.set_submodule_qconfig_list(submodule_qconfig_list) prepared = prepare_qat_pt2e(m, quantizer) - return torch.ao.quantization.move_exported_model_to_train(prepared) + return torchao.quantization.pt2e.move_exported_model_to_train(prepared) def get_converted_sgd_trained_module( self, @@ -583,7 +585,7 @@ def 
get_converted_sgd_trained_module( optimizer.zero_grad() loss.backward() optimizer.step() - return torch.ao.quantization.quantize_pt2e.convert_pt2e(prepared) + return torchao.quantization.pt2e.quantize_pt2e.convert_pt2e(prepared) def split_graph(self, division: int): class SplitGraph(ExportPass): diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 3653cd3176f..c6ba6a4a972 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -607,8 +607,8 @@ def skip_annotation( from executorch.backends.qualcomm.serialization.qc_schema_serialize import ( flatbuffer_to_option, ) - from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner + from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e def prepare_subgm(subgm, subgm_name): # prepare current submodule for quantization annotation diff --git a/backends/transforms/duplicate_dynamic_quant_chain.py b/backends/transforms/duplicate_dynamic_quant_chain.py index 2ca65eec45f..d7f119c4cf6 100644 --- a/backends/transforms/duplicate_dynamic_quant_chain.py +++ b/backends/transforms/duplicate_dynamic_quant_chain.py @@ -9,14 +9,11 @@ import torch -from torch.ao.quantization.pt2e.utils import ( - _filter_sym_size_users, - _is_valid_annotation, -) - from torch.fx.node import map_arg from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torchao.quantization.pt2e.quantizer import is_valid_annotation +from torchao.quantization.pt2e.utils import _filter_sym_size_users logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) @@ -129,7 +126,7 @@ def _maybe_duplicate_dynamic_quantize_chain( dq_node_users = list(dq_node.users.copy()) for user in dq_node_users: annotation = user.meta.get("quantization_annotation", None) - if not _is_valid_annotation(annotation): + if not is_valid_annotation(annotation): return with 
gm.graph.inserting_after(dq_node): new_node = gm.graph.node_copy(dq_node) diff --git a/backends/transforms/test/test_duplicate_dynamic_quant_chain.py b/backends/transforms/test/test_duplicate_dynamic_quant_chain.py index ab965dd347d..79bc56f8780 100644 --- a/backends/transforms/test/test_duplicate_dynamic_quant_chain.py +++ b/backends/transforms/test/test_duplicate_dynamic_quant_chain.py @@ -15,7 +15,6 @@ get_symmetric_quantization_config, XNNPACKQuantizer, ) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e # TODO: Move away from using torch's internal testing utils from torch.testing._internal.common_quantization import ( @@ -23,6 +22,7 @@ QuantizationTestCase, TestHelperModules, ) +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e class MyTestHelperModules: diff --git a/backends/vulkan/quantizer/vulkan_quantizer.py b/backends/vulkan/quantizer/vulkan_quantizer.py index b2f1a658040..736649a016b 100644 --- a/backends/vulkan/quantizer/vulkan_quantizer.py +++ b/backends/vulkan/quantizer/vulkan_quantizer.py @@ -18,9 +18,9 @@ propagate_annotation, QuantizationConfig, ) -from torch.ao.quantization.observer import PerChannelMinMaxObserver -from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer from torch.fx import Node +from torchao.quantization.pt2e.observer import PerChannelMinMaxObserver +from torchao.quantization.pt2e.quantizer import QuantizationSpec, Quantizer __all__ = [ diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index b57710974e8..41866fe4a46 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -23,11 +23,11 @@ EdgeProgramManager, ExecutorchProgramManager, ) +from torch.export import Dim, export, export_for_training, ExportedProgram -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, 
prepare_pt2e -from torch.ao.quantization.quantizer import Quantizer -from torch.export import Dim, export, export_for_training, ExportedProgram +from torchao.quantization.pt2e.quantizer import Quantizer ctypes.CDLL("libvulkan.so.1") diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py index 7572ebd5a5a..ff9e2d85a96 100644 --- a/backends/vulkan/test/test_vulkan_passes.py +++ b/backends/vulkan/test/test_vulkan_passes.py @@ -17,8 +17,8 @@ format_target_name, ) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer import Quantizer +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import Quantizer ################### ## Common Models ## diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer.py b/backends/xnnpack/quantizer/xnnpack_quantizer.py index 229b75f0ed9..66f1fb14750 100644 --- a/backends/xnnpack/quantizer/xnnpack_quantizer.py +++ b/backends/xnnpack/quantizer/xnnpack_quantizer.py @@ -17,11 +17,11 @@ propagate_annotation, QuantizationConfig, ) -from torch.ao.quantization.fake_quantize import ( +from torchao.quantization.pt2e.fake_quantize import ( FakeQuantize, FusedMovingAvgObsFakeQuantize, ) -from torch.ao.quantization.observer import ( +from torchao.quantization.pt2e.observer import ( HistogramObserver, MinMaxObserver, MovingAverageMinMaxObserver, @@ -29,13 +29,13 @@ PerChannelMinMaxObserver, PlaceholderObserver, ) -from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer -from torch.ao.quantization.quantizer.utils import _get_module_name_filter +from torchao.quantization.pt2e.quantizer import QuantizationSpec, Quantizer +from torchao.quantization.pt2e.quantizer.utils import get_module_name_filter if TYPE_CHECKING: - from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor from torch.fx import Node + from torchao.quantization.pt2e import 
ObserverOrFakeQuantizeConstructor __all__ = [ @@ -140,7 +140,7 @@ def get_symmetric_quantization_config( weight_qscheme = ( torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric ) - weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = ( + weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = ( MinMaxObserver ) if is_qat: @@ -228,7 +228,7 @@ def _get_not_module_type_or_name_filter( tp_list: list[Callable], module_name_list: list[str] ) -> Callable[[Node], bool]: module_type_filters = [_get_module_type_filter(tp) for tp in tp_list] - module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list] + module_name_list_filters = [get_module_name_filter(m) for m in module_name_list] def not_module_type_or_name_filter(n: Node) -> bool: return not any(f(n) for f in module_type_filters + module_name_list_filters) @@ -421,7 +421,7 @@ def _annotate_for_quantization_config( module_name_list = list(self.module_name_config.keys()) for module_name, config in self.module_name_config.items(): self._annotate_all_patterns( - model, config, _get_module_name_filter(module_name) + model, config, get_module_name_filter(module_name) ) tp_list = list(self.module_type_config.keys()) diff --git a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py index 4b961bef81d..ff4114cfdd4 100644 --- a/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py +++ b/backends/xnnpack/quantizer/xnnpack_quantizer_utils.py @@ -9,26 +9,26 @@ from executorch.backends.xnnpack.utils.utils import is_depthwise_conv from torch._subclasses import FakeTensor from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix -from torch.ao.quantization.pt2e.export_utils import _WrapperModule -from torch.ao.quantization.pt2e.utils import ( - _get_aten_graph_module_for_pattern, - _is_conv_node, - _is_conv_transpose_node, +from torch.fx import Node +from 
torch.fx.passes.utils.matcher_with_name_node_map_utils import ( + SubgraphMatcherWithNameNodeMap, ) -from torch.ao.quantization.quantizer import ( +from torch.fx.passes.utils.source_matcher_utils import get_source_partitions +from torchao.quantization.pt2e.export_utils import WrapperModule +from torchao.quantization.pt2e.quantizer import ( QuantizationAnnotation, QuantizationSpec, SharedQuantizationSpec, ) -from torch.ao.quantization.quantizer.utils import ( - _annotate_input_qspec_map, - _annotate_output_qspec, +from torchao.quantization.pt2e.quantizer.utils import ( + annotate_input_qspec_map, + annotate_output_qspec, ) -from torch.fx import Node -from torch.fx.passes.utils.matcher_with_name_node_map_utils import ( - SubgraphMatcherWithNameNodeMap, +from torchao.quantization.pt2e.utils import ( + _get_aten_graph_module_for_pattern, + _is_conv_node, + _is_conv_transpose_node, ) -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions __all__ = [ "OperatorConfig", @@ -204,25 +204,25 @@ def _annotate_linear( bias_node = node.args[2] if _is_annotated([node]) is False: # type: ignore[list-item] - _annotate_input_qspec_map( + annotate_input_qspec_map( node, act_node, input_act_qspec, ) - _annotate_input_qspec_map( + annotate_input_qspec_map( node, weight_node, weight_qspec, ) nodes_to_mark_annotated = [node, weight_node] if bias_node: - _annotate_input_qspec_map( + annotate_input_qspec_map( node, bias_node, bias_qspec, ) nodes_to_mark_annotated.append(bias_node) - _annotate_output_qspec(node, output_act_qspec) + annotate_output_qspec(node, output_act_qspec) _mark_nodes_as_annotated(nodes_to_mark_annotated) annotated_partitions.append(nodes_to_mark_annotated) @@ -572,7 +572,7 @@ def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv): "output": output, } - return _WrapperModule(_conv_bn) + return WrapperModule(_conv_bn) # Needed for matching, otherwise the matches gets filtered out due to unused # nodes returned by batch norm diff 
--git a/backends/xnnpack/test/ops/test_check_quant_params.py b/backends/xnnpack/test/ops/test_check_quant_params.py index d05b1fce540..8be59aab50e 100644 --- a/backends/xnnpack/test/ops/test_check_quant_params.py +++ b/backends/xnnpack/test/ops/test_check_quant_params.py @@ -9,8 +9,8 @@ ) from executorch.backends.xnnpack.utils.utils import get_param_tensor from executorch.exir import to_edge_transform_and_lower -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export import export_for_training +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e class TestCheckQuantParams(unittest.TestCase): diff --git a/backends/xnnpack/test/quantizer/test_pt2e_quantization.py b/backends/xnnpack/test/quantizer/test_pt2e_quantization.py index 4243441118e..2f1d81e95b5 100644 --- a/backends/xnnpack/test/quantizer/test_pt2e_quantization.py +++ b/backends/xnnpack/test/quantizer/test_pt2e_quantization.py @@ -10,21 +10,12 @@ from typing import Dict, Tuple import torch +import torchao from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, ) -from torch.ao.quantization import ( - compare_results, - CUSTOM_KEY, - default_per_channel_symmetric_qnnpack_qconfig, - extract_results_from_loggers, - generate_numeric_debug_handle, - NUMERIC_DEBUG_HANDLE_KEY, - observer, - prepare_for_propagation_comparison, -) -from torch.ao.quantization.pt2e.graph_utils import bfs_trace_with_node_process +from torch.ao.quantization import default_per_channel_symmetric_qnnpack_qconfig from torch.ao.quantization.qconfig import ( float_qparams_weight_only_qconfig, per_channel_weight_observer_range_neg_127_to_127, @@ -32,18 +23,9 @@ weight_observer_range_neg_127_to_127, ) from torch.ao.quantization.qconfig_mapping import QConfigMapping -from torch.ao.quantization.quantize_pt2e import ( - convert_pt2e, - prepare_pt2e, - prepare_qat_pt2e, -) -from torch.ao.quantization.quantizer import 
Quantizer -from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer -from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, - PT2EQuantizationTestCase, TestHelperModules, ) from torch.testing._internal.common_utils import ( @@ -51,9 +33,42 @@ TemporaryFileName, TestCase, ) +from torchao.quantization.pt2e import ( + compare_results, + CUSTOM_KEY, + extract_results_from_loggers, + generate_numeric_debug_handle, + NUMERIC_DEBUG_HANDLE_KEY, + observer, + prepare_for_propagation_comparison, +) +from torchao.quantization.pt2e.graph_utils import bfs_trace_with_node_process +from torchao.quantization.pt2e.quantize_pt2e import ( + convert_pt2e, + prepare_pt2e, + prepare_qat_pt2e, +) +from torchao.quantization.pt2e.quantizer import Quantizer +from torchao.quantization.pt2e.quantizer.composable_quantizer import ComposableQuantizer +from torchao.quantization.pt2e.quantizer.embedding_quantizer import EmbeddingQuantizer +from torchao.testing.pt2e.utils import PT2EQuantizationTestCase class TestQuantizePT2E(PT2EQuantizationTestCase): + def _quantize(self, m, quantizer, example_inputs, is_qat: bool = False): + # resetting dynamo cache + torch._dynamo.reset() + + gm = export_for_training(m, example_inputs, strict=True).module() + assert isinstance(gm, torch.fx.GraphModule) + if is_qat: + gm = prepare_qat_pt2e(gm, quantizer) + else: + gm = prepare_pt2e(gm, quantizer) + gm(*example_inputs) + gm = convert_pt2e(gm) + return gm + def _get_pt2e_quantized_linear( self, is_per_channel: bool = False ) -> torch.fx.GraphModule: @@ -287,7 +302,7 @@ def test_embedding_conv_linear_quantization(self) -> None: [embedding_quantizer, dynamic_quantizer, static_quantizer] ) - act_affine_quant_obs = observer.PlaceholderObserver.with_args( + act_affine_quant_obs = torch.ao.quantization.PlaceholderObserver.with_args( 
dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=-128, @@ -404,7 +419,7 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None: m.train() # After wrapping: does not error and swaps the ops accordingly - torch.ao.quantization.allow_exported_model_train_eval(m) # pyre-ignore[6] + torchao.quantization.pt2e.allow_exported_model_train_eval(m) # pyre-ignore[6] m.eval() _assert_ops_are_correct(m, train=False) # pyre-ignore[6] m.train() @@ -419,7 +434,7 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None: m.train() # After prepare and after wrapping: does not error and swaps the ops accordingly - torch.ao.quantization.allow_exported_model_train_eval(m) + torchao.quantization.pt2e.allow_exported_model_train_eval(m) m.eval() _assert_ops_are_correct(m, train=False) m.train() @@ -433,7 +448,7 @@ def _assert_ops_are_correct(m: torch.fx.GraphModule, train: bool) -> None: m.train() # After convert and after wrapping: does not error and swaps the ops accordingly - torch.ao.quantization.allow_exported_model_train_eval(m) + torchao.quantization.pt2e.allow_exported_model_train_eval(m) m.eval() _assert_ops_are_correct(m, train=False) m.train() @@ -783,7 +798,7 @@ def test_extract_results_from_loggers(self) -> None: ref_results = extract_results_from_loggers(m_ref_logger) quant_results = extract_results_from_loggers(m_quant_logger) comparison_results = compare_results( - ref_results, + ref_results, # pyre-ignore[6] quant_results, # pyre-ignore[6] ) for node_summary in comparison_results.values(): diff --git a/backends/xnnpack/test/quantizer/test_representation.py b/backends/xnnpack/test/quantizer/test_representation.py index e52bbbd7ae7..817f7f9e368 100644 --- a/backends/xnnpack/test/quantizer/test_representation.py +++ b/backends/xnnpack/test/quantizer/test_representation.py @@ -8,8 +8,6 @@ XNNPACKQuantizer, ) from torch._higher_order_ops.out_dtype import out_dtype # noqa: F401 -from torch.ao.quantization.quantize_pt2e import 
convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer import Quantizer from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, @@ -17,6 +15,8 @@ skipIfNoQNNPACK, TestHelperModules, ) +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import Quantizer @skipIfNoQNNPACK diff --git a/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py b/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py index 9b0515bad58..0a317ad8822 100644 --- a/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py +++ b/backends/xnnpack/test/quantizer/test_xnnpack_quantizer.py @@ -12,7 +12,6 @@ from torch.ao.quantization import ( default_dynamic_fake_quant, default_dynamic_qconfig, - observer, QConfig, QConfigMapping, ) @@ -28,16 +27,16 @@ convert_to_reference_fx, prepare_fx, ) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export import export_for_training from torch.testing._internal.common_quantization import ( NodeSpec as ns, - PT2EQuantizationTestCase, skip_if_no_torchvision, skipIfNoQNNPACK, TestHelperModules, ) from torch.testing._internal.common_quantized import override_quantized_engine +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.testing.pt2e.utils import PT2EQuantizationTestCase @skipIfNoQNNPACK @@ -575,7 +574,7 @@ def test_dynamic_linear(self): torch.ops.quantized_decomposed.quantize_per_channel.default: 0, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } - act_affine_quant_obs = observer.PlaceholderObserver.with_args( + act_affine_quant_obs = torch.ao.quantization.PlaceholderObserver.with_args( dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=-128, @@ -621,7 +620,7 @@ def test_dynamic_linear_int4_weight(self): torch.ops.quantized_decomposed.quantize_per_channel.default: 0, 
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, } - act_affine_quant_obs = observer.PlaceholderObserver.with_args( + act_affine_quant_obs = torch.ao.quantization.PlaceholderObserver.with_args( dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=-128, @@ -718,7 +717,7 @@ def test_dynamic_linear_with_conv(self): torch.ops.quantized_decomposed.quantize_per_tensor.default: 0, torch.ops.quantized_decomposed.dequantize_per_tensor.default: 0, } - act_affine_quant_obs = observer.PlaceholderObserver.with_args( + act_affine_quant_obs = torch.ao.quantization.PlaceholderObserver.with_args( dtype=torch.qint8, qscheme=torch.per_tensor_affine, quant_min=-128, diff --git a/backends/xnnpack/test/test_xnnpack_utils.py b/backends/xnnpack/test/test_xnnpack_utils.py index 3ff2f0e4c1e..e6c97545d82 100644 --- a/backends/xnnpack/test/test_xnnpack_utils.py +++ b/backends/xnnpack/test/test_xnnpack_utils.py @@ -47,7 +47,6 @@ from torch.ao.quantization import ( # @manual default_per_channel_symmetric_qnnpack_qconfig, - PlaceholderObserver, QConfig, QConfigMapping, ) @@ -55,12 +54,6 @@ from torch.ao.quantization.backend_config.executorch import ( get_executorch_backend_config, ) - -from torch.ao.quantization.observer import ( - per_channel_weight_observer_range_neg_127_to_127, - # default_weight_observer, - weight_observer_range_neg_127_to_127, -) from torch.ao.quantization.qconfig_mapping import ( _get_default_qconfig_mapping_with_default_qconfig, _get_symmetric_qnnpack_qconfig_mapping, @@ -70,11 +63,18 @@ _convert_to_reference_decomposed_fx, prepare_fx, ) - -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export import export_for_training from torch.testing import FileCheck +from torchao.quantization.pt2e import PlaceholderObserver + +from torchao.quantization.pt2e.observer import ( + per_channel_weight_observer_range_neg_127_to_127, + # default_weight_observer, + weight_observer_range_neg_127_to_127, +) + +from 
torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e def randomize_bn(num_features: int, dimensionality: int = 2) -> torch.nn.Module: diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py index cbce817cf4b..fd48837bd72 100644 --- a/backends/xnnpack/test/tester/tester.py +++ b/backends/xnnpack/test/tester/tester.py @@ -55,11 +55,11 @@ ) from executorch.exir.program._program import _transform from torch._export.pass_base import PassType -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer.quantizer import Quantizer from torch.export import export, ExportedProgram from torch.testing import FileCheck from torch.utils._pytree import tree_flatten +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer.quantizer import Quantizer class Stage(ABC): diff --git a/docs/source/backends-coreml.md b/docs/source/backends-coreml.md index 29a4b331be6..1292c80148d 100644 --- a/docs/source/backends-coreml.md +++ b/docs/source/backends-coreml.md @@ -104,7 +104,7 @@ import torchvision.models as models from torchvision.models.mobilenetv2 import MobileNet_V2_Weights from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer from executorch.backends.apple.coreml.partition import CoreMLPartitioner -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from executorch.exir import to_edge_transform_and_lower from executorch.backends.apple.coreml.compiler import CoreMLBackend diff --git a/docs/source/backends-xnnpack.md b/docs/source/backends-xnnpack.md index db1c055dc9c..85919952988 100644 --- a/docs/source/backends-xnnpack.md +++ b/docs/source/backends-xnnpack.md @@ -1,6 +1,6 @@ # XNNPACK Backend -The XNNPACK delegate is the ExecuTorch solution for CPU execution on mobile CPUs. 
[XNNPACK](https://github.com/google/XNNPACK/tree/master) is a library that provides optimized kernels for machine learning operators on Arm and x86 CPUs. +The XNNPACK delegate is the ExecuTorch solution for CPU execution on mobile CPUs. [XNNPACK](https://github.com/google/XNNPACK/tree/master) is a library that provides optimized kernels for machine learning operators on Arm and x86 CPUs. ## Features @@ -18,7 +18,7 @@ The XNNPACK delegate is the ExecuTorch solution for CPU execution on mobile CPUs ## Development Requirements -The XNNPACK delegate does not introduce any development system requirements beyond those required by +The XNNPACK delegate does not introduce any development system requirements beyond those required by the core ExecuTorch runtime. ---- @@ -63,7 +63,7 @@ After generating the XNNPACK-delegated .pte, the model can be tested from Python ## Quantization -The XNNPACK delegate can also be used as a backend to execute symmetrically quantized models. To quantize a PyTorch model for the XNNPACK backend, use the `XNNPACKQuantizer`. `Quantizers` are backend specific, which means the `XNNPACKQuantizer` is configured to quantize models to leverage the quantized operators offered by the XNNPACK Library. +The XNNPACK delegate can also be used as a backend to execute symmetrically quantized models. To quantize a PyTorch model for the XNNPACK backend, use the `XNNPACKQuantizer`. `Quantizers` are backend specific, which means the `XNNPACKQuantizer` is configured to quantize models to leverage the quantized operators offered by the XNNPACK Library. 
### Supported Quantization Schemes The XNNPACK delegate supports the following quantization schemes: @@ -94,8 +94,8 @@ from torchvision.models.mobilenetv2 import MobileNet_V2_Weights from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import XNNPACKQuantizer from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner from executorch.exir import to_edge_transform_and_lower -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer.xnnpack_quantizer import get_symmetric_quantization_config +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer.xnnpack_quantizer import get_symmetric_quantization_config model = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT).eval() sample_inputs = (torch.randn(1, 3, 224, 224), ) diff --git a/docs/source/llm/getting-started.md b/docs/source/llm/getting-started.md index 152162841e4..7d54f4d2dde 100644 --- a/docs/source/llm/getting-started.md +++ b/docs/source/llm/getting-started.md @@ -619,7 +619,7 @@ from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, ) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e ``` ```python diff --git a/docs/source/tutorial-xnnpack-delegate-lowering.md b/docs/source/tutorial-xnnpack-delegate-lowering.md index add60a12deb..12793533766 100644 --- a/docs/source/tutorial-xnnpack-delegate-lowering.md +++ b/docs/source/tutorial-xnnpack-delegate-lowering.md @@ -85,7 +85,7 @@ sample_inputs = (torch.randn(1, 3, 224, 224), ) mobilenet_v2 = export_for_training(mobilenet_v2, sample_inputs).module() # 2-stage export for quantization path -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import 
convert_pt2e, prepare_pt2e from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, diff --git a/docs/source/tutorials_source/export-to-executorch-tutorial.py b/docs/source/tutorials_source/export-to-executorch-tutorial.py index de42cb51bce..2ca6a207d17 100644 --- a/docs/source/tutorials_source/export-to-executorch-tutorial.py +++ b/docs/source/tutorials_source/export-to-executorch-tutorial.py @@ -200,7 +200,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: get_symmetric_quantization_config, XNNPACKQuantizer, ) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config()) prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer) # type: ignore[arg-type] diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py index 73fa4b24d4e..25f2a26ccf7 100644 --- a/examples/arm/aot_arm_compiler.py +++ b/examples/arm/aot_arm_compiler.py @@ -56,10 +56,10 @@ from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.extension.export_util.utils import save_pte_program from tabulate import tabulate +from torch.utils.data import DataLoader # Quantize model if required using the standard export quantization flow. 
-from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.utils.data import DataLoader +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from ..models import MODEL_NAME_TO_MODEL from ..models.model_factory import EagerModelFactory diff --git a/examples/arm/ethos_u_minimal_example.ipynb b/examples/arm/ethos_u_minimal_example.ipynb index 77be8c22447..146a586d0ab 100644 --- a/examples/arm/ethos_u_minimal_example.ipynb +++ b/examples/arm/ethos_u_minimal_example.ipynb @@ -84,7 +84,7 @@ " EthosUQuantizer,\n", " get_symmetric_quantization_config,\n", ")\n", - "from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e\n", + "from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e\n", "\n", "target = \"ethos-u55-128\"\n", "\n", diff --git a/examples/mediatek/aot_utils/oss_utils/utils.py b/examples/mediatek/aot_utils/oss_utils/utils.py index d286a380d5c..bf7b25f07c2 100755 --- a/examples/mediatek/aot_utils/oss_utils/utils.py +++ b/examples/mediatek/aot_utils/oss_utils/utils.py @@ -15,7 +15,7 @@ Precision, ) from executorch.exir.backend.backend_details import CompileSpec -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e def build_executorch_binary( diff --git a/examples/mediatek/model_export_scripts/llama.py b/examples/mediatek/model_export_scripts/llama.py index 34e935bb03b..6a098e2a9b1 100644 --- a/examples/mediatek/model_export_scripts/llama.py +++ b/examples/mediatek/model_export_scripts/llama.py @@ -43,7 +43,7 @@ Precision, ) from executorch.exir.backend.backend_details import CompileSpec -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from tqdm import tqdm warnings.filterwarnings("ignore") diff --git a/examples/models/moshi/mimi/test_mimi.py 
b/examples/models/moshi/mimi/test_mimi.py index 7e2cfb14c49..be3c075913d 100644 --- a/examples/models/moshi/mimi/test_mimi.py +++ b/examples/models/moshi/mimi/test_mimi.py @@ -19,9 +19,9 @@ from huggingface_hub import hf_hub_download from moshi.models import loaders -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export import export, ExportedProgram from torch.utils._pytree import tree_flatten +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e proxies = { "http": "http://fwdproxy:8080", diff --git a/examples/models/phi-3-mini/export_phi-3-mini.py b/examples/models/phi-3-mini/export_phi-3-mini.py index 11c2f3834eb..246b3ccd6c6 100644 --- a/examples/models/phi-3-mini/export_phi-3-mini.py +++ b/examples/models/phi-3-mini/export_phi-3-mini.py @@ -20,8 +20,8 @@ ) from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config from executorch.exir import to_edge -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export import export_for_training +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from transformers import Phi3ForCausalLM diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index 6f7bdac8e15..08a4bde779b 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -81,8 +81,8 @@ from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer -from torch.ao.quantization.observer import MinMaxObserver -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.observer import MinMaxObserver +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e sys.setrecursionlimit(4096) FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" diff 
--git a/examples/qualcomm/oss_scripts/moshi/mimi.py b/examples/qualcomm/oss_scripts/moshi/mimi.py index 6b59a71ae64..d67d378f0ce 100644 --- a/examples/qualcomm/oss_scripts/moshi/mimi.py +++ b/examples/qualcomm/oss_scripts/moshi/mimi.py @@ -37,7 +37,7 @@ from huggingface_hub import hf_hub_download from moshi.models import loaders -from torch.ao.quantization.observer import MinMaxObserver +from torchao.quantization.pt2e.observer import MinMaxObserver def seed_all(seed): diff --git a/examples/qualcomm/scripts/export_example.py b/examples/qualcomm/scripts/export_example.py index acf8a9ab468..515fdda8b41 100644 --- a/examples/qualcomm/scripts/export_example.py +++ b/examples/qualcomm/scripts/export_example.py @@ -16,7 +16,7 @@ from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.extension.export_util.utils import save_pte_program -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e def main() -> None: diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index d8dab88e998..4ed51a96340 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -19,6 +19,7 @@ import numpy as np import torch +import torchao from executorch.backends.qualcomm.quantizer.quantizer import ( ModuleQConfig, QnnQuantizer, @@ -33,8 +34,8 @@ ) from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass -from torch.ao.quantization.observer import MovingAverageMinMaxObserver -from torch.ao.quantization.quantize_pt2e import ( +from torchao.quantization.pt2e.observer import MovingAverageMinMaxObserver +from torchao.quantization.pt2e.quantize_pt2e import ( convert_pt2e, prepare_pt2e, prepare_qat_pt2e, @@ -231,7 +232,7 @@ def ptq_calibrate(captured_model, quantizer, dataset): def qat_train(ori_model, captured_model, quantizer, dataset): data, targets = dataset 
- annotated_model = torch.ao.quantization.move_exported_model_to_train( + annotated_model = torchao.quantization.pt2e.move_exported_model_to_train( prepare_qat_pt2e(captured_model, quantizer) ) optimizer = torch.optim.SGD(annotated_model.parameters(), lr=0.00001) @@ -240,7 +241,9 @@ def qat_train(ori_model, captured_model, quantizer, dataset): print(f"Epoch {i}") if i > 3: # Freeze quantizer parameters - annotated_model.apply(torch.ao.quantization.disable_observer) + annotated_model.apply( + torchao.quantization.pt2e.fake_quantize.disable_observer + ) if i > 2: # Freeze batch norm mean and variance estimates annotated_model.apply(torch.nn.intrinsic.qat.freeze_bn_stats) @@ -251,8 +254,8 @@ def qat_train(ori_model, captured_model, quantizer, dataset): loss.backward() optimizer.step() - return torch.ao.quantization.quantize_pt2e.convert_pt2e( - torch.ao.quantization.move_exported_model_to_eval(annotated_model) + return torchao.quantization.quantize_pt2e.convert_pt2e( + torchao.quantization.move_exported_model_to_eval(annotated_model) ) diff --git a/examples/xnnpack/quantization/example.py b/examples/xnnpack/quantization/example.py index 90a6b94d02b..93831ab8252 100644 --- a/examples/xnnpack/quantization/example.py +++ b/examples/xnnpack/quantization/example.py @@ -29,7 +29,7 @@ _convert_to_reference_decomposed_fx, prepare_fx, ) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from ...models import MODEL_NAME_TO_MODEL from ...models.model_factory import EagerModelFactory diff --git a/examples/xnnpack/quantization/utils.py b/examples/xnnpack/quantization/utils.py index 9e49f15a99d..d7648daf5da 100644 --- a/examples/xnnpack/quantization/utils.py +++ b/examples/xnnpack/quantization/utils.py @@ -11,7 +11,7 @@ XNNPACKQuantizer, ) -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantize_pt2e import 
convert_pt2e, prepare_pt2e from .. import QuantType diff --git a/exir/backend/test/demos/test_xnnpack_qnnpack.py b/exir/backend/test/demos/test_xnnpack_qnnpack.py index 5cbd7f7f659..6038f936ca0 100644 --- a/exir/backend/test/demos/test_xnnpack_qnnpack.py +++ b/exir/backend/test/demos/test_xnnpack_qnnpack.py @@ -28,13 +28,13 @@ _load_for_executorch_from_buffer, ) from executorch.extension.pytree import tree_flatten -from torch.ao.quantization.backend_config.executorch import ( - get_executorch_backend_config, -) -from torch.ao.quantization.observer import ( +from torch.ao.quantization import ( default_dynamic_quant_observer, default_per_channel_weight_observer, ) +from torch.ao.quantization.backend_config.executorch import ( + get_executorch_backend_config, +) from torch.ao.quantization.qconfig_mapping import QConfig, QConfigMapping from torch.ao.quantization.quantize_fx import ( _convert_to_reference_decomposed_fx, diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py index b87ae2dfb58..4f7bccbbf92 100644 --- a/exir/tests/test_memory_planning.py +++ b/exir/tests/test_memory_planning.py @@ -41,10 +41,6 @@ from torch.ao.quantization.backend_config.executorch import ( get_executorch_backend_config, ) -from torch.ao.quantization.observer import ( - default_dynamic_quant_observer, - default_per_channel_weight_observer, -) from torch.ao.quantization.qconfig_mapping import QConfig, QConfigMapping from torch.ao.quantization.quantize_fx import ( _convert_to_reference_decomposed_fx, @@ -55,6 +51,10 @@ from torch.export.exported_program import ExportGraphSignature from torch.fx import Graph, GraphModule, Node from torch.nn import functional as F +from torchao.quantization.pt2e.observer import ( + default_dynamic_quant_observer, + default_per_channel_weight_observer, +) torch.ops.load_library("//executorch/kernels/portable:custom_ops_generated_lib") diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py index 6618c729987..0e32b6c6870 
100644 --- a/exir/tests/test_passes.py +++ b/exir/tests/test_passes.py @@ -71,9 +71,6 @@ from functorch.experimental import control_flow from torch import nn - -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e -from torch.ao.quantization.quantizer import QuantizationSpec from torch.export import export from torch.export.graph_signature import InputKind, InputSpec, TensorArgument from torch.fx import GraphModule, subgraph_rewriter @@ -82,6 +79,9 @@ from torch.testing import FileCheck from torch.utils import _pytree as pytree +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import QuantizationSpec + # pyre-ignore def collect_ops(gm: torch.fx.GraphModule): @@ -1168,7 +1168,7 @@ def forward(self, query, key, value): ).module() # 8w16a quantization - from torch.ao.quantization.observer import ( + from torchao.quantization.pt2e.observer import ( MinMaxObserver, PerChannelMinMaxObserver, ) diff --git a/exir/tests/test_quantization.py b/exir/tests/test_quantization.py index 0a0a85077bb..2c66787c7c4 100644 --- a/exir/tests/test_quantization.py +++ b/exir/tests/test_quantization.py @@ -19,18 +19,17 @@ from executorch.exir.passes.quant_fusion_pass import QuantFusionPass from executorch.exir.passes.spec_prop_pass import SpecPropPass from torch.ao.ns.fx.utils import compute_sqnr -from torch.ao.quantization import QConfigMapping # @manual from torch.ao.quantization.backend_config import get_executorch_backend_config from torch.ao.quantization.qconfig import default_per_channel_symmetric_qnnpack_qconfig -from torch.ao.quantization.quantize_fx import prepare_fx -from torch.ao.quantization.quantize_pt2e import ( +from torch.ao.quantization.qconfig_mapping import QConfigMapping +from torch.ao.quantization.quantize_fx import ( # @manual _convert_to_reference_decomposed_fx, - convert_pt2e, - prepare_pt2e, + prepare_fx, ) from torch.export import export from torch.testing import FileCheck 
from torch.testing._internal.common_quantized import override_quantized_engine +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e # load executorch out variant ops torch.ops.load_library("//executorch/kernels/quantized:custom_ops_generated_lib") diff --git a/exir/tests/test_quantize_io_pass.py b/exir/tests/test_quantize_io_pass.py index ddc0294ba68..f670594616a 100644 --- a/exir/tests/test_quantize_io_pass.py +++ b/exir/tests/test_quantize_io_pass.py @@ -20,8 +20,8 @@ QuantizeOutputs, ) from executorch.exir.tensor import get_scalar_type -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.testing import FileCheck +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e op_str = { "q": "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default", diff --git a/export/export.py b/export/export.py index 7ea4de20a9a..8e9417aa965 100644 --- a/export/export.py +++ b/export/export.py @@ -13,10 +13,10 @@ from executorch.runtime import Runtime, Verification from tabulate import tabulate from torch import nn -from torch.ao.quantization import allow_exported_model_train_eval -from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e from torch.export import ExportedProgram from torchao.quantization import quantize_ +from torchao.quantization.pt2e import allow_exported_model_train_eval +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from torchao.utils import unwrap_tensor_subclass from .recipe import ExportRecipe diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index 323311caeea..3905051ed57 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -35,11 +35,11 @@ from executorch.extension.llm.export.export_passes import RemoveRedundantTransposes from pytorch_tokenizers import get_tokenizer -from torch.ao.quantization.quantize_pt2e import convert_pt2e, 
prepare_pt2e -from torch.ao.quantization.quantizer import Quantizer -from torch.ao.quantization.quantizer.composable_quantizer import ComposableQuantizer from torch.export import export_for_training, ExportedProgram from torch.nn.attention import SDPBackend +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +from torchao.quantization.pt2e.quantizer import Quantizer +from torchao.quantization.pt2e.quantizer.composable_quantizer import ComposableQuantizer from torchao.utils import unwrap_tensor_subclass FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index d7b8b3a92b1..985ff10a396 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -16,8 +16,8 @@ XNNPACKQuantizer, ) -from torch.ao.quantization.quantizer import Quantizer -from torch.ao.quantization.quantizer.embedding_quantizer import EmbeddingQuantizer +from torchao.quantization.pt2e.quantizer import Quantizer +from torchao.quantization.pt2e.quantizer.embedding_quantizer import EmbeddingQuantizer FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" logging.basicConfig(level=logging.INFO, format=FORMAT) @@ -154,7 +154,7 @@ def get_qnn_quantizer( QnnQuantizer, QuantDtype, ) - from torch.ao.quantization.observer import MinMaxObserver + from torchao.quantization.pt2e.observer import MinMaxObserver except ImportError: raise ImportError(