4 changes: 4 additions & 0 deletions backends/cadence/aot/TARGETS
@@ -640,7 +640,11 @@ python_unittest(
],
typing = True,
deps = [
"fbsource//third-party/pypi/parameterized:parameterized",
"//caffe2:torch",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/cadence/aot/quantizer:quantizer",
"//executorch/exir:pass_base",
"//pytorch/ao:torchao",
],
)
193 changes: 193 additions & 0 deletions backends/cadence/aot/tests/test_quantizer_ops.py
@@ -6,17 +6,210 @@

# pyre-strict

import inspect
import unittest
from typing import Callable

import torch
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.backends.cadence.aot.quantizer import quantizer as quantizer_module
from executorch.backends.cadence.aot.quantizer.patterns import AddmmPattern

from executorch.backends.cadence.aot.quantizer.quantizer import (
CadenceAtenQuantizer,
CadenceDefaultQuantizer,
CadenceFusedConvReluQuantizer,
CadenceNopQuantizer,
CadenceQuantizer,
CadenceW8A32MixedQuantizer,
CadenceWakeWordQuantizer,
CadenceWith16BitConvActivationsQuantizer,
CadenceWith16BitLinearActivationsQuantizer,
CadenceWith16BitMatmulActivationsQuantizer,
CadenceWithLayerNormQuantizer,
CadenceWithSoftmaxQuantizer,
qconfig_A16,
qconfig_A8W8,
)
from executorch.exir.pass_base import NodeMetadata
from parameterized import parameterized
from torch._ops import OpOverload
from torchao.quantization.pt2e.quantizer.quantizer import (
Q_ANNOTATION_KEY,
QuantizationAnnotation,
QuantizationSpec,
)

# Type alias for graph builder functions.
# These functions take a test instance and return a graph module and the target op node.
GraphBuilderFn = Callable[
["QuantizerAnnotationTest"], tuple[torch.fx.GraphModule, torch.fx.Node]
]


# Quantizers intentionally excluded from annotation testing.
# These should be explicitly justified when added.
EXCLUDED_FROM_ANNOTATION_TESTING: set[type[CadenceQuantizer]] = {
CadenceDefaultQuantizer, # TODO: T247438143 Add test coverage
CadenceFusedConvReluQuantizer, # TODO: T247438151 Add test coverage
CadenceNopQuantizer, # No-op quantizer, doesn't annotate anything
CadenceW8A32MixedQuantizer, # TODO: T247438158 Add test coverage
CadenceWakeWordQuantizer, # TODO: T247438162 Add test coverage
CadenceWith16BitConvActivationsQuantizer, # TODO: T247438221 Add test coverage
CadenceWithLayerNormQuantizer, # TODO: T247438410 Add test coverage
CadenceWithSoftmaxQuantizer, # TODO: T247438418 Add test coverage
}


# Test case definitions for quantizer annotation tests.
# Format: (name, graph_builder_fn, quantizer_instance, target_op, expected_output_qspec, expected_input_qspecs)
# Adding a new quantizer test only requires adding a tuple to this list.
QUANTIZER_ANNOTATION_TEST_CASES: list[
tuple[
str,
GraphBuilderFn,
CadenceQuantizer,
OpOverload,
QuantizationSpec,
list[QuantizationSpec],
]
] = [
(
"matmul_A16",
lambda self: self._build_matmul_graph(),
CadenceWith16BitMatmulActivationsQuantizer(),
torch.ops.aten.matmul.default,
qconfig_A16.output_activation,
# For matmul, both inputs are activations
[qconfig_A16.input_activation, qconfig_A16.input_activation],
),
(
"linear_A16",
lambda self: self._build_linear_graph(),
CadenceWith16BitLinearActivationsQuantizer(),
torch.ops.aten.linear.default,
qconfig_A16.output_activation,
# For linear: [input_activation, weight]
[qconfig_A16.input_activation, qconfig_A16.weight],
),
]

# Derive the set of tested quantizer classes from the test cases.
# This ensures TESTED_QUANTIZER_CLASSES stays in sync with actual tests.
TESTED_QUANTIZER_CLASSES: set[type[CadenceQuantizer]] = {
type(case[2]) for case in QUANTIZER_ANNOTATION_TEST_CASES
}


class QuantizerAnnotationTest(unittest.TestCase):
"""Unit tests for verifying quantizer annotations are correctly applied."""

def _build_matmul_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
"""Build a simple graph with a matmul operation."""
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(4, 8))
y = builder.placeholder("y", torch.randn(8, 4))
matmul = builder.call_operator(
op=torch.ops.aten.matmul.default,
args=(x, y),
meta=NodeMetadata(
{"source_fn_stack": [("matmul", torch.ops.aten.matmul.default)]}
),
)
builder.output([matmul])
gm = builder.get_graph_module()

matmul_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten.matmul.default,
)
self.assertEqual(len(matmul_nodes), 1, "Should find exactly one matmul node")
return gm, matmul_nodes[0]

def _build_linear_graph(self) -> tuple[torch.fx.GraphModule, torch.fx.Node]:
"""Build a simple graph with a linear operation (no bias)."""
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(1, 10))
weight = builder.placeholder("weight", torch.randn(5, 10))
linear = builder.call_operator(
op=torch.ops.aten.linear.default,
args=(x, weight),
meta=NodeMetadata(
{"source_fn_stack": [("linear", torch.ops.aten.linear.default)]}
),
)
builder.output([linear])
gm = builder.get_graph_module()

linear_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten.linear.default,
)
self.assertEqual(len(linear_nodes), 1, "Should find exactly one linear node")
return gm, linear_nodes[0]

@parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
def test_quantizer_annotation(
self,
name: str,
graph_builder_fn: GraphBuilderFn,
quantizer: CadenceQuantizer,
target: OpOverload,
Copilot AI (Dec 8, 2025):
The target parameter is defined in the test method signature but is never used in the test implementation (lines 161-186). If this parameter is intended for documentation or future use, consider either:

  1. Using it to validate that the op_node's target matches the expected target
  2. Removing it from the test case definition if it's not needed

Example validation could be: self.assertEqual(op_node.target, target, "Operation target mismatch")

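For illustration, a minimal sketch of how the suggested check could slot into the parameterized test, right after the graph is built; the assertion message is illustrative, not part of the PR:

    # Build the graph, then confirm the builder produced the expected op
    # before annotating (sketch of the reviewer's suggestion).
    gm, op_node = graph_builder_fn(self)
    self.assertEqual(op_node.target, target, "Operation target mismatch")
    quantizer.annotate(gm)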
Copilot AI (Dec 8, 2025):
[nitpick] The target parameter is defined but not used in the test method body. Consider adding a verification that op_node.target == target to ensure the test is validating the expected operation type. This would make the test more robust by explicitly checking that the graph builder function created the expected operation.

expected_output_qspec: QuantizationSpec,
expected_input_qspecs: list[QuantizationSpec],
) -> None:
"""Parameterized test for quantizer annotations."""
gm, op_node = graph_builder_fn(self)

quantizer.annotate(gm)

annotation: QuantizationAnnotation = op_node.meta[Q_ANNOTATION_KEY]
self.assertTrue(annotation._annotated)

# Verify output annotation
self.assertEqual(annotation.output_qspec, expected_output_qspec)

# Verify input annotations
self.assertEqual(len(annotation.input_qspec_map), len(expected_input_qspecs))
for i, (input_node, input_qspec) in enumerate(
annotation.input_qspec_map.items()
):
expected_arg = op_node.args[i]
assert isinstance(expected_arg, torch.fx.Node)
Copilot AI (Dec 8, 2025):
Using a bare assert statement in a test is inconsistent with unittest conventions. Consider using self.assertIsInstance(expected_arg, torch.fx.Node) instead. This provides better error messages and is consistent with the rest of the test framework.

Suggested change:
- assert isinstance(expected_arg, torch.fx.Node)
+ self.assertIsInstance(expected_arg, torch.fx.Node)

self.assertEqual(
input_node,
expected_arg,
f"Input node mismatch at index {i}",
)
self.assertEqual(
input_qspec,
expected_input_qspecs[i],
f"Input qspec mismatch at index {i}",
)

def test_all_quantizers_have_annotation_tests(self) -> None:
"""Ensure every CadenceQuantizer subclass is either tested or explicitly excluded."""
# Get all CadenceQuantizer subclasses defined in the quantizer module
all_quantizers: set[type[CadenceQuantizer]] = set()
for _, obj in inspect.getmembers(quantizer_module, inspect.isclass):
if (
issubclass(obj, CadenceQuantizer)
and obj is not CadenceQuantizer
and obj.__module__ == quantizer_module.__name__
):
all_quantizers.add(obj)

# Check for missing tests
untested = (
all_quantizers - TESTED_QUANTIZER_CLASSES - EXCLUDED_FROM_ANNOTATION_TESTING
)
if untested:
untested_names = sorted(cls.__name__ for cls in untested)
self.fail(
f"The following CadenceQuantizer subclasses are not tested in "
f"test_quantizer_annotation and not in EXCLUDED_FROM_ANNOTATION_TESTING: "
f"{untested_names}. Please add test cases or explicitly exclude them."
)


class QuantizerOpsPreserveTest(unittest.TestCase):