
# pyre-strict

+ import inspect
import unittest
from typing import Callable

import torch
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
+ from executorch.backends.cadence.aot.quantizer import quantizer as quantizer_module
from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern
+ from executorch.exir.pass_base import NodeMetadata

from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceAtenQuantizer,
    CadenceDefaultQuantizer,
+     CadenceFusedConvReluQuantizer,
+     CadenceNopQuantizer,
    CadenceQuantizer,
    CadenceW8A32MixedQuantizer,
+     CadenceWakeWordQuantizer,
+     CadenceWith16BitConvActivationsQuantizer,
    CadenceWith16BitLinearActivationsQuantizer,
    CadenceWith16BitMatmulActivationsQuantizer,
+     CadenceWithLayerNormQuantizer,
+     CadenceWithSoftmaxQuantizer,
    qconfig_A16,
    qconfig_A8W8,
)
- from executorch.exir.pass_base import NodeMetadata
from parameterized import parameterized
from torch._ops import OpOverload
from torchao.quantization.pt2e.quantizer.quantizer import (
    QuantizationSpec,
)

- # Type alias for graph builder functions
+ # Type alias for graph builder functions.
+ # These functions take a test instance and return a graph module and the target op node.
GraphBuilderFn = Callable[
    ["QuantizerAnnotationTest"], tuple[torch.fx.GraphModule, torch.fx.Node]
]


+ # Quantizers intentionally excluded from annotation testing.
+ # Every exclusion added here should be explicitly justified.
+ EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
+     CadenceDefaultQuantizer,  # TODO: T247438143 Add test coverage
+     CadenceFusedConvReluQuantizer,  # TODO: T247438151 Add test coverage
+     CadenceNopQuantizer,  # No-op quantizer, doesn't annotate anything
+     CadenceW8A32MixedQuantizer,  # TODO: T247438158 Add test coverage
+     CadenceWakeWordQuantizer,  # TODO: T247438162 Add test coverage
+     CadenceWith16BitConvActivationsQuantizer,  # TODO: T247438221 Add test coverage
+     CadenceWithLayerNormQuantizer,  # TODO: T247438410 Add test coverage
+     CadenceWithSoftmaxQuantizer,  # TODO: T247438418 Add test coverage
+ }
+ 
+ 
+ # Test case definitions for quantizer annotation tests.
+ # Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs)
+ # Adding a new quantizer test only requires adding a tuple to this list.
+ QUANTIZER_ANNOTATION_TEST_CASES: list[
+     tuple[
+         str,
+         GraphBuilderFn,
+         CadenceQuantizer,
+         OpOverload,
+         QuantizationSpec,
+         list[QuantizationSpec],
+     ]
+ ] = [
+     (
+         "matmul_A16",
+         lambda self: self._build_matmul_graph(),
+         CadenceWith16BitMatmulActivationsQuantizer(),
+         torch.ops.aten.matmul.default,
+         qconfig_A16.output_activation,
+         # For matmul, both inputs are activations
+         [qconfig_A16.input_activation, qconfig_A16.input_activation],
+     ),
+     (
+         "linear_A16",
+         lambda self: self._build_linear_graph(),
+         CadenceWith16BitLinearActivationsQuantizer(),
+         torch.ops.aten.linear.default,
+         qconfig_A16.output_activation,
+         # For linear: [input_activation, weight]
+         [qconfig_A16.input_activation, qconfig_A16.weight],
+     ),
+ ]
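As the comment above notes, extending coverage only requires one more tuple in this list. Purely as a hedged illustration (not part of this change): an entry migrating a quantizer off the exclusion list might look like the sketch below, assuming a hypothetical `_build_conv2d_graph` helper, which does not exist in this diff, and assuming conv inputs annotate as [input_activation, weight], which would have to match the actual conv pattern.

# Hypothetical sketch only: _build_conv2d_graph is not defined anywhere in this change,
# and the expected input qspecs must match how the conv pattern actually annotates.
QUANTIZER_ANNOTATION_TEST_CASES.append(
    (
        "conv2d_A16",
        lambda self: self._build_conv2d_graph(),
        CadenceWith16BitConvActivationsQuantizer(),
        torch.ops.aten.conv2d.default,
        qconfig_A16.output_activation,
        [qconfig_A16.input_activation, qconfig_A16.weight],
    )
)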
+ 
+ # Derive the set of tested quantizer classes from the test cases.
+ # This ensures TESTED_QUANTIZER_CLASSES stays in sync with actual tests.
+ TESTED_QUANTIZER_CLASSES: set[type[CadenceQuantizer]] = {
+     type(case[2]) for case in QUANTIZER_ANNOTATION_TEST_CASES
+ }
+ 
+ 
class QuantizerAnnotationTest(unittest.TestCase):
    """Unit tests for verifying quantizer annotations are correctly applied."""

@@ -85,28 +148,7 @@ def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
        self.assertEqual(len(linear_nodes), 1, "Should find exactly one linear node")
        return gm, linear_nodes[0]

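The graph-builder helpers referenced by the test cases (`_build_matmul_graph`, `_build_linear_graph`) are collapsed in this diff; each builds a single-op graph and returns the graph module plus the node whose annotation the test inspects. A minimal standalone sketch of that idea, assuming GraphBuilder exposes placeholder / call_operator / output / get_graph_module and using illustrative shapes (the collapsed helpers may differ, e.g. by attaching NodeMetadata):

def example_build_matmul_graph() -> tuple[torch.fx.GraphModule, torch.fx.Node]:
    # Build x @ y as a single aten.matmul node; shapes are arbitrary examples.
    builder = GraphBuilder()
    x = builder.placeholder("x", torch.randn(4, 8))
    y = builder.placeholder("y", torch.randn(8, 16))
    matmul = builder.call_operator(op=torch.ops.aten.matmul.default, args=(x, y))
    builder.output([matmul])
    gm = builder.get_graph_module()
    # Return the lone matmul node so a test can check its quantization annotation.
    matmul_nodes = [
        n
        for n in gm.graph.nodes
        if n.op == "call_function" and n.target == torch.ops.aten.matmul.default
    ]
    assert len(matmul_nodes) == 1
    return gm, matmul_nodes[0]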
-     @parameterized.expand(
-         [
-             (
-                 "matmul_A16",
-                 lambda self: self._build_matmul_graph(),
-                 CadenceWith16BitMatmulActivationsQuantizer(),
-                 torch.ops.aten.matmul.default,
-                 qconfig_A16.output_activation,
-                 # For matmul, both inputs are activations
-                 [qconfig_A16.input_activation, qconfig_A16.input_activation],
-             ),
-             (
-                 "linear_A16",
-                 lambda self: self._build_linear_graph(),
-                 CadenceWith16BitLinearActivationsQuantizer(),
-                 torch.ops.aten.linear.default,
-                 qconfig_A16.output_activation,
-                 # For linear: [input_activation, weight]
-                 [qconfig_A16.input_activation, qconfig_A16.weight],
-             ),
-         ]
-     )
+     @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
    def test_quantizer_annotation(
        self,
        name: str,
@@ -128,23 +170,45 @@ def test_quantizer_annotation(
        self.assertEqual(annotation.output_qspec, expected_output_qspec)

        # Verify input annotations
-         # Build actual_specs in the fixed order defined by op_node.args
        self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
-         actual_specs = [
-             annotation.input_qspec_map[op_node.args[i]]
-             for i in range(len(expected_input_qspecs))
-         ]
- 
-         # Compare expected vs actual specs
-         for i, (expected, actual) in enumerate(
-             zip(expected_input_qspecs, actual_specs)
+         for i, (input_node, input_qspec) in enumerate(
+             annotation.input_qspec_map.items()
        ):
            self.assertEqual(
-                 actual,
-                 expected,
+                 input_node,
+                 op_node.args[i],
+                 f"Input node mismatch at index {i}",
+             )
+             self.assertEqual(
+                 input_qspec,
+                 expected_input_qspecs[i],
                f"Input qspec mismatch at index {i}",
            )

+     def test_all_quantizers_have_annotation_tests(self) -> None:
+         """Ensure every CadenceQuantizer subclass is either tested or explicitly excluded."""
+         # Get all CadenceQuantizer subclasses defined in the quantizer module
+         all_quantizers: set[type[CadenceQuantizer]] = set()
+         for _, obj in inspect.getmembers(quantizer_module, inspect.isclass):
+             if (
+                 issubclass(obj, CadenceQuantizer)
+                 and obj is not CadenceQuantizer
+                 and obj.__module__ == quantizer_module.__name__
+             ):
+                 all_quantizers.add(obj)
+ 
+         # Check for missing tests
+         untested = (
+             all_quantizers - TESTED_QUANTIZER_CLASSES - EXCLUDED_FROM_ANNOTATION_TESTING
+         )
+         if untested:
+             untested_names = sorted(cls.__name__ for cls in untested)
+             self.fail(
+                 f"The following CadenceQuantizer subclasses are not tested in "
+                 f"test_quantizer_annotation and not in EXCLUDED_FROM_ANNOTATION_TESTING: "
+                 f"{untested_names}. Please add test cases or explicitly exclude them."
+             )
+ 

class QuantizerOpsPreserveTest(unittest.TestCase):
    def test_mixed_w8a32_ops_to_preserve(self) -> None: