
Commit 98a2371

RahulC7 authored and facebook-github-bot committed
Adding Test To Ensure All Future Quantizers Are Tested
Summary: We first create a list of quantizers that are currently not tested (we'll slowly reduce this to zero), and then we add a test to ensure that all future quantizers get tested using this framework. In order to do this, we needed to refactor how the current test is set up, specifically the parameterization.

Reviewed By: hsharma35

Differential Revision: D88055443
1 parent 9637b08 commit 98a2371
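
The guard test added below boils down to a reflection check: enumerate every CadenceQuantizer subclass defined in the quantizer module, then subtract the classes covered by the parameterized test cases and the classes explicitly excluded; anything left over fails the test. A minimal sketch of that pattern, condensed from the diff (the helper name find_untested_quantizers is illustrative only and not part of the change):

import inspect

from executorch.backends.cadence.aot.quantizer import quantizer as quantizer_module
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer


def find_untested_quantizers(
    tested: set[type[CadenceQuantizer]],
    excluded: set[type[CadenceQuantizer]],
) -> set[type[CadenceQuantizer]]:
    # Collect every CadenceQuantizer subclass (other than the base class)
    # that is defined in the quantizer module itself.
    all_quantizers = {
        obj
        for _, obj in inspect.getmembers(quantizer_module, inspect.isclass)
        if issubclass(obj, CadenceQuantizer)
        and obj is not CadenceQuantizer
        and obj.__module__ == quantizer_module.__name__
    }
    # A quantizer is covered if it is tested or explicitly waived; return the rest.
    return all_quantizers - tested - excluded

In the diff, this logic lives in test_all_quantizers_have_annotation_tests, with the tested and excluded sets coming from the new module-level TESTED_QUANTIZER_CLASSES and EXCLUDED_FROM_ANNOTATION_TESTING constants.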

File tree

1 file changed (+99, −35 lines)


backends/cadence/aot/tests/test_quantizer_ops.py

Lines changed: 99 additions & 35 deletions
@@ -6,24 +6,32 @@
 
 # pyre-strict
 
+import inspect
 import unittest
 from typing import Callable
 
 import torch
 from executorch.backends.cadence.aot.graph_builder import GraphBuilder
+from executorch.backends.cadence.aot.quantizer import quantizer as quantizer_module
 from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern
+from executorch.exir.pass_base import NodeMetadata
 
 from executorch.backends.cadence.aot.quantizer.quantizer import (
     CadenceAtenQuantizer,
     CadenceDefaultQuantizer,
+    CadenceFusedConvReluQuantizer,
+    CadenceNopQuantizer,
     CadenceQuantizer,
     CadenceW8A32MixedQuantizer,
+    CadenceWakeWordQuantizer,
+    CadenceWith16BitConvActivationsQuantizer,
     CadenceWith16BitLinearActivationsQuantizer,
     CadenceWith16BitMatmulActivationsQuantizer,
+    CadenceWithLayerNormQuantizer,
+    CadenceWithSoftmaxQuantizer,
     qconfig_A16,
     qconfig_A8W8,
 )
-from executorch.exir.pass_base import NodeMetadata
 from parameterized import parameterized
 from torch._ops import OpOverload
 from torchao.quantization.pt2e.quantizer.quantizer import (
@@ -32,12 +40,67 @@
     QuantizationSpec,
 )
 
-# Type alias for graph builder functions
+# Type alias for graph builder functions.
+# These functions take a test instance and return a graph module and the target op node.
 GraphBuilderFn = Callable[
     ["QuantizerAnnotationTest"], tuple[torch.fx.GraphModule, torch.fx.Node]
 ]
 
 
+# Quantizers intentionally excluded from annotation testing.
+# These should be explicitly justified when added.
+EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
+    CadenceDefaultQuantizer,  # TODO: T247438143 Add test coverage
+    CadenceFusedConvReluQuantizer,  # TODO: T247438151 Add test coverage
+    CadenceNopQuantizer,  # No-op quantizer, doesn't annotate anything
+    CadenceW8A32MixedQuantizer,  # TODO: T247438158 Add test coverage
+    CadenceWakeWordQuantizer,  # TODO: T247438162 Add test coverage
+    CadenceWith16BitConvActivationsQuantizer,  # TODO: T247438221 Add test coverage
+    CadenceWithLayerNormQuantizer,  # TODO: T247438410 Add test coverage
+    CadenceWithSoftmaxQuantizer,  # TODO: T247438418 Add test coverage
+}
+
+
+# Test case definitions for quantizer annotation tests.
+# Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs)
+# Adding a new quantizer test only requires adding a tuple to this list.
+QUANTIZER_ANNOTATION_TEST_CASES: list[
+    tuple[
+        str,
+        GraphBuilderFn,
+        CadenceQuantizer,
+        OpOverload,
+        QuantizationSpec,
+        list[QuantizationSpec],
+    ]
+] = [
+    (
+        "matmul_A16",
+        lambda self: self._build_matmul_graph(),
+        CadenceWith16BitMatmulActivationsQuantizer(),
+        torch.ops.aten.matmul.default,
+        qconfig_A16.output_activation,
+        # For matmul, both inputs are activations
+        [qconfig_A16.input_activation, qconfig_A16.input_activation],
+    ),
+    (
+        "linear_A16",
+        lambda self: self._build_linear_graph(),
+        CadenceWith16BitLinearActivationsQuantizer(),
+        torch.ops.aten.linear.default,
+        qconfig_A16.output_activation,
+        # For linear: [input_activation, weight]
+        [qconfig_A16.input_activation, qconfig_A16.weight],
+    ),
+]
+
+# Derive the set of tested quantizer classes from the test cases.
+# This ensures TESTED_QUANTIZER_CLASSES stays in sync with actual tests.
+TESTED_QUANTIZER_CLASSES: set[type[CadenceQuantizer]] = {
+    type(case[2]) for case in QUANTIZER_ANNOTATION_TEST_CASES
+}
+
+
 class QuantizerAnnotationTest(unittest.TestCase):
     """Unit tests for verifying quantizer annotations are correctly applied."""
 
@@ -85,28 +148,7 @@ def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
         self.assertEqual(len(linear_nodes), 1, "Should find exactly one linear node")
         return gm, linear_nodes[0]
 
-    @parameterized.expand(
-        [
-            (
-                "matmul_A16",
-                lambda self: self._build_matmul_graph(),
-                CadenceWith16BitMatmulActivationsQuantizer(),
-                torch.ops.aten.matmul.default,
-                qconfig_A16.output_activation,
-                # For matmul, both inputs are activations
-                [qconfig_A16.input_activation, qconfig_A16.input_activation],
-            ),
-            (
-                "linear_A16",
-                lambda self: self._build_linear_graph(),
-                CadenceWith16BitLinearActivationsQuantizer(),
-                torch.ops.aten.linear.default,
-                qconfig_A16.output_activation,
-                # For linear: [input_activation, weight]
-                [qconfig_A16.input_activation, qconfig_A16.weight],
-            ),
-        ]
-    )
+    @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
     def test_quantizer_annotation(
         self,
         name: str,
@@ -128,23 +170,45 @@ def test_quantizer_annotation(
         self.assertEqual(annotation.output_qspec, expected_output_qspec)
 
         # Verify input annotations
-        # Build actual_specs in the fixed order defined by op_node.args
         self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
-        actual_specs = [
-            annotation.input_qspec_map[op_node.args[i]]
-            for i in range(len(expected_input_qspecs))
-        ]
-
-        # Compare expected vs actual specs
-        for i, (expected, actual) in enumerate(
-            zip(expected_input_qspecs, actual_specs)
+        for i, (input_node, input_qspec) in enumerate(
+            annotation.input_qspec_map.items()
         ):
             self.assertEqual(
-                actual,
-                expected,
+                input_node,
+                op_node.args[i],
+                f"Input node mismatch at index {i}",
+            )
+            self.assertEqual(
+                input_qspec,
+                expected_input_qspecs[i],
                 f"Input qspec mismatch at index {i}",
             )
 
+    def test_all_quantizers_have_annotation_tests(self) -> None:
+        """Ensure every CadenceQuantizer subclass is either tested or explicitly excluded."""
+        # Get all CadenceQuantizer subclasses defined in the quantizer module
+        all_quantizers: set[type[CadenceQuantizer]] = set()
+        for _, obj in inspect.getmembers(quantizer_module, inspect.isclass):
+            if (
+                issubclass(obj, CadenceQuantizer)
+                and obj is not CadenceQuantizer
+                and obj.__module__ == quantizer_module.__name__
+            ):
+                all_quantizers.add(obj)
+
+        # Check for missing tests
+        untested = (
+            all_quantizers - TESTED_QUANTIZER_CLASSES - EXCLUDED_FROM_ANNOTATION_TESTING
+        )
+        if untested:
+            untested_names = sorted(cls.__name__ for cls in untested)
+            self.fail(
+                f"The following CadenceQuantizer subclasses are not tested in "
+                f"test_quantizer_annotation and not in EXCLUDED_FROM_ANNOTATION_TESTING: "
+                f"{untested_names}. Please add test cases or explicitly exclude them."
+            )
+
 
 class QuantizerOpsPreserveTest(unittest.TestCase):
     def test_mixed_w8a32_ops_to_preserve(self) -> None: