Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 60 additions & 6 deletions python/tvm/relay/op/contrib/ethosu.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,15 +201,20 @@ def __init__(self, func_body: tvm.relay.Function):
from tvm.relay.backend.contrib.ethosu.util import RequantArgs

activation = None
separate_padding = None

if str(func_body.op) in self.activation_map.keys():
activation = func_body
requantize_op = activation.args[0]
else:
requantize_op = func_body
bias_add = requantize_op.args[0]
qnn_conv2d = bias_add.args[0]
if isinstance(qnn_conv2d.args[0], relay.Call) and str(qnn_conv2d.args[0].op) == "nn.pad":
separate_padding = qnn_conv2d.args[0]
data_layout = qnn_conv2d.attrs.data_layout
self.kernel_layout = qnn_conv2d.attrs.kernel_layout

# We consider the weights & biases as params as it should be a Constant
self.weights = TensorParams(
qnn_conv2d.args[QConv2DArgs.WEIGHTS.value],
Expand All @@ -224,8 +229,11 @@ def __init__(self, func_body: tvm.relay.Function):
requantize_op.args[RequantArgs.IFM_SCALE.value],
requantize_op.args[RequantArgs.IFM_ZERO_POINT.value],
)
ifm_tensor = (
separate_padding.args[0] if separate_padding else qnn_conv2d.args[QConv2DArgs.IFM.value]
)
self.ifm = TensorParams(
qnn_conv2d.args[QConv2DArgs.IFM.value],
ifm_tensor,
data_layout,
qnn_conv2d.args[QConv2DArgs.IFM_SCALE.value],
qnn_conv2d.args[QConv2DArgs.IFM_ZERO_POINT.value],
Expand All @@ -237,7 +245,10 @@ def __init__(self, func_body: tvm.relay.Function):
requantize_op.args[RequantArgs.OFM_ZERO_POINT.value],
)
attrs = qnn_conv2d.attrs
self.padding = attrs.padding

pad_value = int(qnn_conv2d.args[QConv2DArgs.IFM_ZERO_POINT.value].data.asnumpy())
self.padding = self.extract_padding(attrs.padding, separate_padding, pad_value)

self.strides = attrs.strides
self.dilation = attrs.dilation
self.activation = activation
Expand All @@ -250,6 +261,37 @@ def __init__(self, func_body: tvm.relay.Function):
if self.groups == self.weights.shape[channels_axis[self.kernel_layout]]:
self.is_depthwise = True

@staticmethod
def extract_padding(
    operator_padding: Tuple[int, int, int, int],
    separate_padding: Optional[relay.Call],
    pad_value: int,
) -> Optional[Tuple[int, int, int, int]]:
    """
    Fold an optional stand-alone nn.pad into the convolution's own padding.

    Convolution operations can sometimes have padding represented as a separate
    padding operation before the convolution operation itself. This checks
    whether the two representations can be combined into a single padding
    attribute on the NPU convolution.

    Parameters
    ----------
    operator_padding : Tuple[int, int, int, int]
        Padding attribute of the convolution, unpacked here as
        (top, left, bottom, right).
    separate_padding : Optional[relay.Call]
        The nn.pad call feeding the convolution, or None when there is none.
    pad_value : int
        The value the convolution pads with; the separate pad must use the
        same value for the two to be merged.

    Returns
    -------
    padding : Optional[Tuple[int, int, int, int]]
        `operator_padding` unchanged when there is no separate pad; the
        element-wise sum (top, left, bottom, right) when the separate pad can
        be merged; None when it cannot, so that the caller can reject the
        pattern and leave the nn.pad to be offloaded separately.
    """
    # NOTE(review): the merged result is built as a list although the
    # annotation says Tuple; kept as-is so existing callers see the same type.
    # NOTE(review): the pad mode of the nn.pad call is not inspected here -
    # confirm that only constant padding can reach this point.
    if separate_padding is None:
        return operator_padding

    # The second argument of the pad call carries its fill value; merging is
    # only valid when it equals the value the convolution itself pads with.
    if pad_value != int(separate_padding.args[1].data.asnumpy()):
        return None

    pad_width = separate_padding.attrs["pad_width"]
    if len(pad_width) != 4:
        return None

    # Only the two middle axes may be padded (presumably height and width of
    # an NHWC tensor - layout is not checked in this function).
    if list(pad_width[0]) != [0, 0] or list(pad_width[3]) != [0, 0]:
        return None

    top, left, bottom, right = operator_padding
    return [
        top + pad_width[1][0],
        left + pad_width[2][0],
        bottom + pad_width[1][1],
        right + pad_width[2][1],
    ]

def is_valid(self) -> bool:
"""
This function checks whether QnnConv2D has compatible attributes with the NPU
Expand All @@ -267,7 +309,7 @@ def is_valid(self) -> bool:
return False
if not check_dilation(self.dilation):
return False
if not check_padding(self.padding, self.padding_bounds):
if not self.padding or not check_padding(self.padding, self.padding_bounds):
return False
legal_groups = [1, self.ofm.shape[3]]
if self.groups not in legal_groups:
Expand Down Expand Up @@ -437,7 +479,7 @@ def is_valid(self):
return False
if not check_dilation(self.dilation):
return False
if not check_padding(self.padding, self.padding_bounds):
if not self.padding or not check_padding(self.padding, self.padding_bounds):
return False
if self.weights.layout != "HWOI":
return False
Expand All @@ -453,8 +495,14 @@ def qnn_conv2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
"""
This function creates the pattern for qnn.conv2D with optional fused RELU activation.
"""
optional_pad = is_op("nn.pad")(wildcard(), is_constant())
qnn_conv2d = is_op("qnn.conv2d")(
wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
optional_pad | wildcard(),
is_constant(),
is_constant(),
is_constant(),
is_constant(),
is_constant(),
).has_attr({"kernel_layout": "HWIO"})
bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant())
req = is_op("qnn.requantize")(
Expand All @@ -468,8 +516,14 @@ def qnn_depthwise_conv2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
"""
This function creates the pattern for depthwise qnn.conv2D with optional fused RELU activation.
"""
optional_pad = is_op("nn.pad")(wildcard(), is_constant())
qnn_conv2d = is_op("qnn.conv2d")(
wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant()
optional_pad | wildcard(),
is_constant(),
is_constant(),
is_constant(),
is_constant(),
is_constant(),
).has_attr({"kernel_layout": "HWOI"})
bias_add = is_op("nn.bias_add")(qnn_conv2d, is_constant())
req = is_op("qnn.requantize")(
Expand Down
11 changes: 9 additions & 2 deletions tests/python/contrib/test_ethosu/infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,10 +419,17 @@ def compute_ofm_shape(ifm_shape, padding, kernel_shape, strides, dilation=[1, 1]
assert len(strides) == 2
assert len(dilation) == 2
assert len(kernel_shape) == 2
if padding.lower() == "valid":
if isinstance(padding, tuple):
h = (
ifm_shape[1] - (kernel_shape[0] - 1) * dilation[0] + padding[0] + padding[2]
) // strides[0]
w = (
ifm_shape[2] - (kernel_shape[1] - 1) * dilation[1] + padding[1] + padding[3]
) // strides[1]
elif padding.lower() == "valid":
h = math.ceil((ifm_shape[1] - (kernel_shape[0] - 1) * dilation[0]) / strides[0])
w = math.ceil((ifm_shape[2] - (kernel_shape[1] - 1) * dilation[1]) / strides[1])
if padding.lower() == "same":
elif padding.lower() == "same":
h = math.ceil(ifm_shape[1] / strides[0])
w = math.ceil(ifm_shape[2] / strides[1])
ofm_shape = [ifm_shape[0], h, w, ifm_shape[3]]
Expand Down
68 changes: 64 additions & 4 deletions tests/python/contrib/test_ethosu/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,43 @@ def conv2d(x):
padding=padding,
dilations=dilation,
)
if activation:
if activation == "RELU":
op = tf.nn.relu(op)
return op

infra.compare_tvm_with_tflite(conv2d, [ifm_shape], accel_type)


def test_tflite_conv2d_with_separate_pad():
    """Run a conv2d preceded by an explicit tf.pad through infra.compare_tvm_with_tflite."""
    np.random.seed(0)

    ifm_shape = (1, 55, 34, 3)
    kernel_shape = (3, 2)
    strides = (1, 1)
    dilation = (2, 1)
    # Indexed below as (top, left, bottom, right).
    padding = (0, 0, 1, 1)

    @tf.function
    def conv2d(x):
        pad_top, pad_left, pad_bottom, pad_right = padding
        kernel_h, kernel_w = kernel_shape
        stride_h, stride_w = strides

        # Padding is applied as its own op so the graph contains a separate pad
        # in front of a "VALID" convolution.
        padded = tf.pad(
            x,
            [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]],
            "CONSTANT",
        )
        weight_shape = [kernel_h, kernel_w, ifm_shape[3], 3]
        weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
        return tf.nn.conv2d(
            padded,
            weight,
            strides=[1, stride_h, stride_w, 1],
            padding="VALID",
            dilations=dilation,
        )

    infra.compare_tvm_with_tflite(conv2d, [ifm_shape], "ethos-u55-256")


@pytest.mark.parametrize("ifm_shape", [(1, 214, 227, 2), (1, 27, 42, 3)])
@pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)])
@pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))])
Expand Down Expand Up @@ -116,7 +146,7 @@ def conv2d_double(x):
padding=padding,
dilations=dilation,
)
if activation:
if activation == "RELU":
op2 = tf.nn.relu(op2)
return op2

Expand Down Expand Up @@ -152,7 +182,7 @@ def conv_invalid_scale(x):
padding=padding,
dilations=dilation,
)
if activation:
if activation == "RELU":
op = tf.nn.relu(op)
return op

Expand Down Expand Up @@ -187,13 +217,43 @@ def depthwise_conv2d(x):
op = tf.nn.depthwise_conv2d(
x, weight, strides=tf_strides, padding=padding, dilations=dilation
)
if activation_function:
if activation_function == "RELU":
op = tf.nn.relu(op)
return op

infra.compare_tvm_with_tflite(depthwise_conv2d, [ifm_shape], accel_type)


def test_tflite_depthwise_conv2d_with_separate_pad():
    """Run a depthwise conv2d preceded by an explicit tf.pad through infra.compare_tvm_with_tflite."""
    np.random.seed(0)

    ifm_shape = (1, 23, 32, 7)
    kernel_shape = (1, 2)
    strides = (3, 2)
    dilation = (1, 1)
    # Indexed below as (top, left, bottom, right).
    padding = (0, 0, 1, 1)

    @tf.function
    def depthwise_conv2d(x):
        pad_top, pad_left, pad_bottom, pad_right = padding
        kernel_h, kernel_w = kernel_shape
        stride_h, stride_w = strides

        # Padding is applied as its own op so the graph contains a separate pad
        # in front of a "VALID" depthwise convolution.
        padded = tf.pad(
            x,
            [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]],
            "CONSTANT",
        )
        weight_shape = [kernel_h, kernel_w, ifm_shape[3], 1]
        weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
        return tf.nn.depthwise_conv2d(
            padded,
            weight,
            strides=[1, stride_h, stride_w, 1],
            padding="VALID",
            dilations=dilation,
        )

    infra.compare_tvm_with_tflite(depthwise_conv2d, [ifm_shape], "ethos-u55-256")


@pytest.mark.parametrize(
"accel_type",
ACCEL_TYPES,
Expand Down
Loading