Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1445,6 +1445,70 @@ def __call__(self, *args, **kwargs):
pass


class ExpandDimsRewriter(DFPatternCallback):
"""Legalize expand dims to a reshape operator."""

def __init__(self):
super().__init__(require_type=True, rewrite_once=True)
self.pattern = (
wildcard().has_attr({"Composite": ethosu_patterns.ExpandDimsParams.composite_name})
)(None)

def callback(
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
) -> tvm.relay.Expr:
params = ethosu_patterns.ExpandDimsParams(post.op.body)
return relay.op.reshape(post.args[0], newshape=params.output.shape)


@ir.transform.module_pass(opt_level=1)
class LegalizeExpandDims:
"""This is the pass that wraps ExpandDimsRewriter."""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(ExpandDimsRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


class SqueezeRewriter(DFPatternCallback):
"""Legalize squeeze to a reshape operator."""

def __init__(self):
super().__init__(require_type=True, rewrite_once=True)
self.pattern = (
wildcard().has_attr({"Composite": ethosu_patterns.SqueezeParams.composite_name})
)(None)

def callback(
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
) -> tvm.relay.Expr:
params = ethosu_patterns.SqueezeParams(post.op.body)
return relay.op.reshape(post.args[0], newshape=params.output.shape)


@ir.transform.module_pass(opt_level=1)
class LegalizeSqueeze:
"""This is the pass that wraps SqueezeRewriter."""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(SqueezeRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


@ir.transform.module_pass(opt_level=1)
class LegalizeEthosU:
"""This is the pass to call graph-rewrites to perform graph transformation
Expand Down Expand Up @@ -1477,6 +1541,8 @@ def transform_module(
mod = LegalizeSigmoid()(mod)
mod = LegalizeRequantize()(mod)
mod = LegalizeResize2d()(mod)
mod = LegalizeExpandDims()(mod)
mod = LegalizeSqueeze()(mod)
mod = LegalizeReshape()(mod)
mod = LegalizeStridedSlice()(mod)
mod = LegalizeNoOps()(mod)
Expand Down
84 changes: 78 additions & 6 deletions python/tvm/relay/op/contrib/ethosu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,19 +1214,22 @@ class ConcatParams:

def __init__(self, func_body):
self.concat = func_body
self.is_qnn_variant = self.concat.op.name == "qnn.concatenate"
self.input_tensors = [TensorParams(tensor) for tensor in list(func_body.args[0])]
self.input_scales = [s.data.asnumpy() for s in list(func_body.args[1])]
self.input_zero_points = [zp.data.asnumpy() for zp in list(func_body.args[2])]
self.axis = func_body.attrs.axis

if self.is_qnn_variant:
self.input_scales = [s.data.asnumpy() for s in list(func_body.args[1])]
self.input_zero_points = [zp.data.asnumpy() for zp in list(func_body.args[2])]

def is_valid(self):
"""Checks whether Concatenate has compatible attributes with the hardware"""
if not check_valid_dtypes(self.input_tensors, supported_dtypes=[np.int8]):
return False
# Check that the scales and zero points of input tensors are the same
if not all(self.input_scales == self.input_scales[0]):
if self.is_qnn_variant and not all(self.input_scales == self.input_scales[0]):
return False
if not all(self.input_zero_points == self.input_zero_points[0]):
if self.is_qnn_variant and not all(self.input_zero_points == self.input_zero_points[0]):
return False

input_dim = len(self.input_tensors[0].shape)
Expand All @@ -1244,6 +1247,8 @@ def is_valid(self):
output_shape = self.concat.checked_type.shape
if len(output_shape) != input_dim:
return False
if len(output_shape) > 3 and output_shape[0] != 1:
return False
return True


Expand All @@ -1252,8 +1257,11 @@ def concat_pattern():
tensors = is_tuple(None)
scales = is_tuple(None)
zero_points = is_tuple(None)
concat = is_op("qnn.concatenate")(tensors, scales, zero_points, is_constant(), is_constant())
return concat
qnn_concat = is_op("qnn.concatenate")(
tensors, scales, zero_points, is_constant(), is_constant()
)
concat = is_op("concatenate")(tensors)
return concat | qnn_concat


class SplitParams:
Expand Down Expand Up @@ -1433,6 +1441,60 @@ def resize2d_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
return quant | is_op("image.resize2d")(wildcard()).has_attr({"method": "nearest_neighbor"})


class ExpandDimsParams:
"""
This class will parse a call to a ethos-u.expand_dims composite function
and extract the parameter information.
"""

composite_name = "ethos-u.expand_dims"

def __init__(self, func_body):
self.expand_dims = func_body
self.input = TensorParams(func_body.args[0])
self.output = TensorParams(func_body)

def is_valid(self):
"""Checks whether expand_dims has compatible attributes with the hardware."""
if not check_dimensions(self.input) or not check_dimensions(self.output):
return False
if not check_valid_dtypes([self.input, self.output], supported_dtypes=[np.int8]):
return False
return True


def expand_dims_pattern():
"""Create the pattern for expand_dims."""
return is_op("expand_dims")(wildcard())


class SqueezeParams:
"""
This class will parse a call to a ethos-u.squeeze composite function
and extract the parameter information.
"""

composite_name = "ethos-u.squeeze"

def __init__(self, func_body):
self.squeeze = func_body
self.input = TensorParams(func_body.args[0])
self.output = TensorParams(func_body)

def is_valid(self):
"""Checks whether squeeze has compatible attributes with the hardware."""
if not check_dimensions(self.output):
return False
if not check_valid_dtypes([self.input, self.output], supported_dtypes=[np.int8]):
return False
return True


def squeeze_pattern():
"""Create the pattern for squeeze."""
return is_op("squeeze")(wildcard())


@register_pattern_table("ethos-u")
def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
return [
Expand Down Expand Up @@ -1533,6 +1595,16 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal
resize2d_pattern(),
lambda pat: Resize2dParams(pat).is_valid(),
),
(
ExpandDimsParams.composite_name,
expand_dims_pattern(),
lambda pat: ExpandDimsParams(pat).is_valid(),
),
(
SqueezeParams.composite_name,
squeeze_pattern(),
lambda pat: SqueezeParams(pat).is_valid(),
),
]


Expand Down
56 changes: 56 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,28 @@ def create_model():
_compare_ethosu_with_reference(ethosu_mod, input_data, output_data, accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize("ifm_shape,axis", [((2,), 0), ((1, 3, 3), 2)])
def test_tflite_expand_dims(accel_type, ifm_shape, axis):
@tf.function
def expand_dims_func(x):
return tf.expand_dims(x, axis=axis)

_compare_tvm_with_tflite(expand_dims_func, [ifm_shape], accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
"ifm_shape,axis", [((1, 1, 2, 1), 0), ((1, 3, 3, 1), 3), ((1, 1, 2, 1), None)]
)
def test_tflite_squeeze(accel_type, ifm_shape, axis):
@tf.function
def squeeze_func(x):
return tf.squeeze(x, axis=axis)

_compare_tvm_with_tflite(squeeze_func, [ifm_shape], accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
"ifm_shape,size",
Expand Down Expand Up @@ -1100,5 +1122,39 @@ def conv2d_transpose(x):
_compare_tvm_with_tflite(conv2d_transpose, [ifm_shape], accel_type=accel_type)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
"ifm_shapes,axis",
[
([(1, 2, 2), (1, 2, 2), (1, 2, 2)], 2),
([(5, 4), (5, 4)], 1),
([(1,), (1,)], 0),
([(3, 1), (3, 1), (3, 1), (3, 1)], 0),
],
)
def test_tflite_pack(accel_type, ifm_shapes, axis):
@tf.function
def pack_func(*inputs):
return tf.stack(inputs, axis=axis)

# TODO(lhutton1) For now output is not bit exact with TFLite.
# This is because TFLite reference kernels are not being used.
# For this, TFLite will need upgrading to 2.6.
_compare_tvm_with_tflite(pack_func, ifm_shapes, accel_type, output_tolerance=1)


@pytest.mark.parametrize("accel_type", ACCEL_TYPES)
@pytest.mark.parametrize(
"ifm_shape,axis",
[[(1, 2, 3, 4), 1], [(2, 3), 1], [(5, 6, 7), 2]],
)
def test_tflite_unpack(accel_type, ifm_shape, axis):
@tf.function
def unpack_func(x):
return tf.unstack(x, axis=axis)

_compare_tvm_with_tflite(unpack_func, [ifm_shape], accel_type)


if __name__ == "__main__":
pytest.main([__file__])
Loading