Adding Test To Ensure All Future Quantizers Are Tested #16099
base: main
Changes from all commits: d810bc3, d300b0f, 7c2f3fb, 34d9597
@@ -6,17 +6,210 @@

```python
# pyre-strict

import inspect
import unittest
from typing import Callable

import torch
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.backends.cadence.aot.quantizer import quantizer as quantizer_module
from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern

from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceAtenQuantizer,
    CadenceDefaultQuantizer,
    CadenceFusedConvReluQuantizer,
    CadenceNopQuantizer,
    CadenceQuantizer,
    CadenceW8A32MixedQuantizer,
    CadenceWakeWordQuantizer,
    CadenceWith16BitConvActivationsQuantizer,
    CadenceWith16BitLinearActivationsQuantizer,
    CadenceWith16BitMatmulActivationsQuantizer,
    CadenceWithLayerNormQuantizer,
    CadenceWithSoftmaxQuantizer,
    qconfig_A16,
    qconfig_A8W8,
)
from executorch.exir.pass_base import NodeMetadata
from parameterized import parameterized
from torch._ops import OpOverload
from torchao.quantization.pt2e.quantizer.quantizer import (
    Q_ANNOTATION_KEY,
    QuantizationAnnotation,
    QuantizationSpec,
)


# Type alias for graph builder functions.
# These functions take a test instance and return a graph module and the target op node.
GraphBuilderFn = Callable[
    ["QuantizerAnnotationTest"], tuple[torch.fx.GraphModule, torch.fx.Node]
]


# Quantizers intentionally excluded from annotation testing.
# These should be explicitly justified when added.
EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
    CadenceDefaultQuantizer,  # TODO: T247438143 Add test coverage
    CadenceFusedConvReluQuantizer,  # TODO: T247438151 Add test coverage
    CadenceNopQuantizer,  # No-op quantizer, doesn't annotate anything
    CadenceW8A32MixedQuantizer,  # TODO: T247438158 Add test coverage
    CadenceWakeWordQuantizer,  # TODO: T247438162 Add test coverage
    CadenceWith16BitConvActivationsQuantizer,  # TODO: T247438221 Add test coverage
    CadenceWithLayerNormQuantizer,  # TODO: T247438410 Add test coverage
    CadenceWithSoftmaxQuantizer,  # TODO: T247438418 Add test coverage
}


# Test case definitions for quantizer annotation tests.
# Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs)
# Adding a new quantizer test only requires adding a tuple to this list.
QUANTIZER_ANNOTATION_TEST_CASES: list[
    tuple[
        str,
        GraphBuilderFn,
        CadenceQuantizer,
        OpOverload,
        QuantizationSpec,
        list[QuantizationSpec],
    ]
] = [
    (
        "matmul_A16",
        lambda self: self._build_matmul_graph(),
        CadenceWith16BitMatmulActivationsQuantizer(),
        torch.ops.aten.matmul.default,
        qconfig_A16.output_activation,
        # For matmul, both inputs are activations
        [qconfig_A16.input_activation, qconfig_A16.input_activation],
    ),
    (
        "linear_A16",
        lambda self: self._build_linear_graph(),
        CadenceWith16BitLinearActivationsQuantizer(),
        torch.ops.aten.linear.default,
        qconfig_A16.output_activation,
        # For linear: [input_activation, weight]
        [qconfig_A16.input_activation, qconfig_A16.weight],
    ),
]


# Derive the set of tested quantizer classes from the test cases.
# This ensures TESTED_QUANTIZER_CLASSES stays in sync with actual tests.
TESTED_QUANTIZER_CLASSES: set[type[CadenceQuantizer]] = {
    type(case[2]) for case in QUANTIZER_ANNOTATION_TEST_CASES
}

class QuantizerAnnotationTest(unittest.TestCase):
    """Unit tests for verifying quantizer annotations are correctly applied."""

    def _build_matmul_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
        """Build a simple graph with a matmul operation."""
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(4, 8))
        y = builder.placeholder("y", torch.randn(8, 4))
        matmul = builder.call_operator(
            op=torch.ops.aten.matmul.default,
            args=(x, y),
            meta=NodeMetadata(
                {"source_fn_stack": [("matmul", torch.ops.aten.matmul.default)]}
            ),
        )
        builder.output([matmul])
        gm = builder.get_graph_module()

        matmul_nodes = gm.graph.find_nodes(
            op="call_function",
            target=torch.ops.aten.matmul.default,
        )
        self.assertEqual(len(matmul_nodes), 1, "Should find exactly one matmul node")
        return gm, matmul_nodes[0]

    def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
        """Build a simple graph with a linear operation (no bias)."""
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(1, 10))
        weight = builder.placeholder("weight", torch.randn(5, 10))
        linear = builder.call_operator(
            op=torch.ops.aten.linear.default,
            args=(x, weight),
            meta=NodeMetadata(
                {"source_fn_stack": [("linear", torch.ops.aten.linear.default)]}
            ),
        )
        builder.output([linear])
        gm = builder.get_graph_module()

        linear_nodes = gm.graph.find_nodes(
            op="call_function",
            target=torch.ops.aten.linear.default,
        )
        self.assertEqual(len(linear_nodes), 1, "Should find exactly one linear node")
        return gm, linear_nodes[0]

    @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
    def test_quantizer_annotation(
        self,
        name: str,
        graph_builder_fn: GraphBuilderFn,
        quantizer: CadenceQuantizer,
        target: OpOverload,
        expected_output_qspec: QuantizationSpec,
        expected_input_qspecs: list[QuantizationSpec],
    ) -> None:
        """Parameterized test for quantizer annotations."""
        gm, op_node = graph_builder_fn(self)

        quantizer.annotate(gm)

        annotation: QuantizationAnnotation = op_node.meta[Q_ANNOTATION_KEY]
        self.assertTrue(annotation._annotated)

        # Verify output annotation
        self.assertEqual(annotation.output_qspec, expected_output_qspec)

        # Verify input annotations
        self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
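        # The positional comparison below assumes input_qspec_map preserves
        # insertion order and that entries were inserted in op_node.args order.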
        for i, (input_node, input_qspec) in enumerate(
            annotation.input_qspec_map.items()
        ):
            expected_arg = op_node.args[i]
            assert isinstance(expected_arg, torch.fx.Node)
```

Suggested change (review comment):

```diff
-            assert isinstance(expected_arg, torch.fx.Node)
+            self.assertIsInstance(expected_arg, torch.fx.Node)
```
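(`self.assertIsInstance` is generally preferable in unittest code: it produces a descriptive failure message on mismatch, and unlike a bare `assert` it is not stripped out when Python runs with optimizations enabled via `-O`.)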
Review comment:

The `target` parameter is defined in the test method signature but is never used in the test implementation (lines 161-186). If this parameter is intended for documentation or future use, consider either removing it or validating it against the built graph. Example validation could be:
```python
self.assertEqual(op_node.target, target, "Operation target mismatch")
```
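For context on the PR title: the diff imports `inspect` and `quantizer_module`, which suggests that the portion of the change not visible in this excerpt adds a check that every quantizer class is either tested or explicitly excluded. A minimal sketch of such a check, assuming the names defined above (the actual test in the PR may differ), could be a method on `QuantizerAnnotationTest`:

```python
def test_all_quantizers_tested_or_excluded(self) -> None:
    # Collect every CadenceQuantizer subclass visible in the quantizer module.
    all_quantizers = {
        obj
        for _, obj in inspect.getmembers(quantizer_module, inspect.isclass)
        if issubclass(obj, CadenceQuantizer) and obj is not CadenceQuantizer
    }
    # Base/wrapper classes such as CadenceAtenQuantizer would likely need
    # filtering as well; this sketch omits that detail.
    untracked = (
        all_quantizers - TESTED_QUANTIZER_CLASSES - EXCLUDED_FROM_ANNOTATION_TESTING
    )
    # Fail loudly when a quantizer is added without a test case or an
    # explicit, justified exclusion.
    self.assertFalse(
        untracked,
        f"Quantizers with neither a test case nor an exclusion: {untracked}",
    )
```

This is what makes the scheme self-enforcing: a newly added quantizer class fails CI until its author either adds a tuple to `QUANTIZER_ANNOTATION_TEST_CASES` or a justified entry to `EXCLUDED_FROM_ANNOTATION_TESTING`.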