Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backends/xnnpack/operators/op_avg_pooling2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
XNNGraph,
XNode,
)
from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_FLAG_KEEP_DIMS


@register_node_visitor
Expand Down Expand Up @@ -67,7 +68,7 @@ def define_node(
dilation_width=0, # Unused
input_id=input_id,
output_id=output_id,
flags=0,
flags=XNN_FLAG_KEEP_DIMS,
),
debug_handle=debug_handle,
)
Expand Down
3 changes: 2 additions & 1 deletion backends/xnnpack/operators/op_max_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
XNNMaxPooling2d,
XNode,
)
from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_FLAG_KEEP_DIMS


@register_node_visitor
Expand Down Expand Up @@ -80,7 +81,7 @@ def define_node(
kwargs["dilation_height"] = dilation[0]
kwargs["dilation_width"] = dilation[1]

kwargs["flags"] = 0
kwargs["flags"] = XNN_FLAG_KEEP_DIMS

ser_node = XNode(
xnode_union=XNNMaxPooling2d(
Expand Down
3 changes: 2 additions & 1 deletion backends/xnnpack/operators/op_mean_dim.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
XNNGraph,
XNode,
)
from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_FLAG_KEEP_DIMS


@register_node_visitor
Expand Down Expand Up @@ -70,7 +71,7 @@ def define_node(

ser_node = XNode(
xnode_union=XNNGlobalAvgPooling2d(
input_id=input_id, output_id=output_id, flags=0
input_id=input_id, output_id=output_id, flags=XNN_FLAG_KEEP_DIMS
),
debug_handle=debug_handle,
)
Expand Down
12 changes: 8 additions & 4 deletions backends/xnnpack/utils/xnnpack_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,25 @@
UINT32_MAX = 4294967295
XNN_EXTRA_BYTES = 16
XNN_MAX_TENSOR_DIMS = 6
XNN_FLAG_SPARSE_INFERENCE = 0x00000001
XNN_FLAG_HINT_SPARSE_INFERENCE = XNN_FLAG_SPARSE_INFERENCE
XNN_FLAG_FP16_INFERENCE = 0x00000002
XNN_FLAG_HINT_FP16_INFERENCE = XNN_FLAG_FP16_INFERENCE
XNN_FLAG_HINT_SPARSE_INFERENCE = 0x00000001
XNN_FLAG_HINT_FP16_INFERENCE = 0x00000002
XNN_FLAG_FORCE_FP16_INFERENCE = 0x00000004
XNN_FLAG_BASIC_PROFILING = 0x00000008
XNN_FLAG_JIT = 0x00000010
XNN_FLAG_DEPTHWISE_CONVOLUTION = 0x00000001
XNN_FLAG_TRANSPOSE_WEIGHTS = 0x00000001
XNN_FLAG_INPUT_NHWC = 0x00000002
XNN_FLAG_TENSORFLOW_SAME_PADDING = 0x00000004
XNN_FLAG_TRANSPOSE_B = XNN_FLAG_TRANSPOSE_WEIGHTS
XNN_FLAG_TRANSPOSE_A = 0x00000002
XNN_FLAG_TENSORFLOW_RESHAPE_2D = 0x00000004
XNN_FLAG_TENSORFLOW_LEGACY_MODE = 0x00000004
XNN_FLAG_FP32_STATIC_WEIGHTS = 0x00000008
XNN_FLAG_ALIGN_CORNERS = 0x00000008
XNN_FLAG_YIELD_WORKERS = 0x00000010
XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER = 0x00000020
XNN_FLAG_KEEP_DIMS = 0x00000040
XNN_EXTRA_QUANTIZATION_PARAMS = 8
XNN_VALUE_FLAG_EXTERNAL_INPUT = 0x00000001
XNN_VALUE_FLAG_EXTERNAL_OUTPUT = 0x00000002
XNN_VALUE_FLAG_PERSISTENT = 0x00000004
Expand Down