Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backends/xnnpack/operators/op_avg_pooling2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
XNNGraph,
XNode,
)
from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_FLAG_KEEP_DIMS


@register_node_visitor
Expand Down Expand Up @@ -67,7 +68,7 @@ def define_node(
dilation_width=0, # Unused
input_id=input_id,
output_id=output_id,
flags=0,
flags=XNN_FLAG_KEEP_DIMS,
),
debug_handle=debug_handle,
)
Expand Down
3 changes: 2 additions & 1 deletion backends/xnnpack/operators/op_max_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
XNNMaxPooling2d,
XNode,
)
from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_FLAG_KEEP_DIMS


@register_node_visitor
Expand Down Expand Up @@ -80,7 +81,7 @@ def define_node(
kwargs["dilation_height"] = dilation[0]
kwargs["dilation_width"] = dilation[1]

kwargs["flags"] = 0
kwargs["flags"] = XNN_FLAG_KEEP_DIMS

ser_node = XNode(
xnode_union=XNNMaxPooling2d(
Expand Down
3 changes: 2 additions & 1 deletion backends/xnnpack/operators/op_mean_dim.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
XNNGraph,
XNode,
)
from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_FLAG_KEEP_DIMS


@register_node_visitor
Expand Down Expand Up @@ -70,7 +71,7 @@ def define_node(

ser_node = XNode(
xnode_union=XNNGlobalAvgPooling2d(
input_id=input_id, output_id=output_id, flags=0
input_id=input_id, output_id=output_id, flags=XNN_FLAG_KEEP_DIMS
),
debug_handle=debug_handle,
)
Expand Down
33 changes: 30 additions & 3 deletions backends/xnnpack/test/ops/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ def test_fp32_linear(self):
num_batch_dims=num_batch_dims,
)

def test_qc8_linear(self):
for use_bias in (True, False):
for num_batch_dims in range(1, 3):
self._test_linear(
lambda in_size, out_size: torch.nn.Linear(
in_size, out_size, bias=use_bias # noqa
),
uses_bias=use_bias,
quant_type="per_channel",
num_batch_dims=num_batch_dims,
)

def test_fp32_addmm(self):
"""
Note that the ConvertToLinear pass requires the weight matrix to be transposed.
Expand Down Expand Up @@ -107,7 +119,7 @@ def forward(self, x):
),
num_batch_dims=num_batch_dims,
uses_bias=use_bias,
quant=True,
quant_type="per_tensor",
)

def test_qs8_linear(self):
Expand All @@ -119,6 +131,7 @@ def test_qs8_linear(self):
),
uses_bias=use_bias,
num_batch_dims=num_batch_dims,
quant_type="per_tensor",
)

@unittest.skip("XNNPACK currently only supports per-channel dynamic quantization.")
Expand Down Expand Up @@ -726,7 +739,7 @@ def _test_linear(
make_module,
uses_bias,
num_batch_dims=1,
quant=False,
quant_type=None,
dtype: torch.dtype = torch.float,
atol=1e-03,
):
Expand All @@ -746,6 +759,8 @@ def _test_linear(
input_sizes = [4, 37, 17]
output_sizes = [4, 17, 37]

quant = quant_type is not None

"""
Note that torch.nn.Linear maps to aten.mm.default (no bias) or aten.addmm.default (bias),
which are then transformed into aten.linear.default by the ConvertToLinear pass.
Expand All @@ -769,7 +784,19 @@ def _test_linear(
tester = Tester(module, inputs, dynamic_shapes=dynamic_shape)

if quant:
tester.quantize()
if quant_type == "per_channel":
quant_config = get_symmetric_quantization_config(
is_per_channel=True,
is_dynamic=False,
)
elif quant_type == "per_tensor":
quant_config = get_symmetric_quantization_config(
is_per_channel=False,
is_dynamic=False,
)
else:
raise ValueError(f"Unsupported quant type {quant_type}")
tester.quantize(Quantize(quantization_config=quant_config))

tester.export()
tester.check_count({aten_op: 1})
Expand Down
12 changes: 8 additions & 4 deletions backends/xnnpack/utils/xnnpack_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,25 @@
UINT32_MAX = 4294967295
XNN_EXTRA_BYTES = 16
XNN_MAX_TENSOR_DIMS = 6
XNN_FLAG_SPARSE_INFERENCE = 0x00000001
XNN_FLAG_HINT_SPARSE_INFERENCE = XNN_FLAG_SPARSE_INFERENCE
XNN_FLAG_FP16_INFERENCE = 0x00000002
XNN_FLAG_HINT_FP16_INFERENCE = XNN_FLAG_FP16_INFERENCE
XNN_FLAG_HINT_SPARSE_INFERENCE = 0x00000001
XNN_FLAG_HINT_FP16_INFERENCE = 0x00000002
XNN_FLAG_FORCE_FP16_INFERENCE = 0x00000004
XNN_FLAG_BASIC_PROFILING = 0x00000008
XNN_FLAG_JIT = 0x00000010
XNN_FLAG_DEPTHWISE_CONVOLUTION = 0x00000001
XNN_FLAG_TRANSPOSE_WEIGHTS = 0x00000001
XNN_FLAG_INPUT_NHWC = 0x00000002
XNN_FLAG_TENSORFLOW_SAME_PADDING = 0x00000004
XNN_FLAG_TRANSPOSE_B = XNN_FLAG_TRANSPOSE_WEIGHTS
XNN_FLAG_TRANSPOSE_A = 0x00000002
XNN_FLAG_TENSORFLOW_RESHAPE_2D = 0x00000004
XNN_FLAG_TENSORFLOW_LEGACY_MODE = 0x00000004
XNN_FLAG_FP32_STATIC_WEIGHTS = 0x00000008
XNN_FLAG_ALIGN_CORNERS = 0x00000008
XNN_FLAG_YIELD_WORKERS = 0x00000010
XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER = 0x00000020
XNN_FLAG_KEEP_DIMS = 0x00000040
XNN_EXTRA_QUANTIZATION_PARAMS = 8
XNN_VALUE_FLAG_EXTERNAL_INPUT = 0x00000001
XNN_VALUE_FLAG_EXTERNAL_OUTPUT = 0x00000002
XNN_VALUE_FLAG_PERSISTENT = 0x00000004
Expand Down