Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tvm/relay/backend/contrib/ethosu/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ def relay_to_tir(mod: tvm.ir.IRModule) -> tvm.ir.IRModule:
mod = OutlineCompilerFunctions("ethos-u")(mod)
mod = LegalizeEthosU()(mod)
mod = LUTsOptimizer()(mod)
mod = relay.transform.InferType()(mod)
mod = IdentityOptimizer()(mod)
mod = LayoutOptimizer()(mod)
mod = relay.transform.InferType()(mod)
Expand Down
61 changes: 54 additions & 7 deletions src/relay/backend/contrib/ethosu/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,24 +115,33 @@ class RemoveRedundantIdentities : public MixedModeMutator {
Expr Rewrite_(const CallNode* pre, const Expr& post) override {
Call call = Downcast<Call>(post);

// only consider rewrite if current op is an NPU compute op.
// don't consider rewrite if current op is an identity or concatenate.
if (!call->op->IsInstance<OpNode>()) {
return post;
}
const auto* op = call->op.as<OpNode>();
std::string op_name = op->name;
if (op_name.substr(0, 15) != "contrib.ethosu." || op_name == "contrib.ethosu.identity") {
if (op_name == "contrib.ethosu.identity" || op_name == "concatenate") {
return post;
}

// check if we can rewrite parent identity operations to current call.
bool needs_rewrite = false;
Array<Expr> new_args;
for (const auto& arg : call->args) {
if (const auto* parent_callnode = arg.as<CallNode>()) {
Expr current_arg = arg;

// expand tuple to get parent op if we run into one - nested tuples are not supported.
if (const auto* tuple_get_item = arg.as<TupleGetItemNode>()) {
const auto* tuple = tuple_get_item->tuple.as<TupleNode>();
current_arg = tuple->fields[tuple_get_item->index];
}

if (const auto* parent_callnode = current_arg.as<CallNode>()) {
if (const auto* parent_op = parent_callnode->op.as<OpNode>()) {
Call parent_call = GetRef<Call>(parent_callnode);
if (parent_op->name == "contrib.ethosu.identity" && IdentityDoesNothing(parent_call)) {
if (parent_op->name == "contrib.ethosu.identity" && IdentityDoesNothing(parent_call) &&
CheckIdentityBetweenTransformOperations(call, parent_call)) {
needs_rewrite = true;
new_args.push_back(parent_call->args[0]);
continue;
Expand All @@ -143,7 +152,10 @@ class RemoveRedundantIdentities : public MixedModeMutator {
}

if (needs_rewrite) {
return Call(call->op, new_args, call->attrs, call->type_args);
Call new_call = Call(call->op, new_args, call->attrs, call->type_args);
// since we are only removing an identity, we know the type information has not changed
new_call->checked_type_ = call->checked_type_;
return new_call;
}
return post;
}
Expand All @@ -156,6 +168,41 @@ class RemoveRedundantIdentities : public MixedModeMutator {
bool has_no_activation = attrs->activation == "NONE";
return does_not_requantize && has_no_activation;
}

// Decides whether an identity operation that sits between two transform
// (non-compute) operations may be removed. Returns true when removal is safe,
// false when the identity must be preserved.
bool CheckIdentityBetweenTransformOperations(const Call& call, const Call& identity_call) {
const auto* op = call->op.as<OpNode>();
// the non-compute ("transform") operations this guard applies to
std::vector<std::string> nc_ops = {"reshape", "strided_slice"};

if (op && (std::find(nc_ops.begin(), nc_ops.end(), op->name) != nc_ops.end())) {
// check if the parent to identity operation is also a non-compute operation,
// if it isn't we can safely remove the identity in question by returning true.
const auto* identity_arg = identity_call->args[0].as<CallNode>();
if (!identity_arg) {
return true;
}
const auto* identity_arg_op = identity_arg->op.as<OpNode>();
if (!identity_arg_op ||
!(std::find(nc_ops.begin(), nc_ops.end(), identity_arg_op->name) != nc_ops.end())) {
return true;
}

// both neighbours of the identity are transform ops: compare their output
// ranks. checked_type_ must already be populated, i.e. InferType has run.
const auto* call_tt = call->checked_type_.as<TensorTypeNode>();
const auto* identity_arg_tt = identity_arg->checked_type_.as<TensorTypeNode>();
CHECK(call_tt && identity_arg_tt)
<< "InferType should be run before RemoveRedundantIdentities";

// we can only remove the identity operation if the second non-compute operation
// in the sequence does not reduce the dimensionality of the output to the first
// non-compute operation. Doing so could lead to data being accessed incorrectly
// by the subsequent compute operation due to the reduction in dimensionality.
size_t first_transform_op_dims = identity_arg_tt->shape.size();
size_t second_transform_op_dims = call_tt->shape.size();
if (second_transform_op_dims < first_transform_op_dims) {
return false;
}
}
// default: the identity is removable
return true;
}
};

/*!
Expand All @@ -177,8 +224,8 @@ tvm::transform::Pass IdentityOptimizer() {
}
return mod;
};
return tvm::transform::CreateModulePass(pass_func, 0,
"relay.backend.contrib.ethos-u.IdentityOptimizer", {});
return tvm::transform::CreateModulePass(
pass_func, 0, "relay.backend.contrib.ethos-u.IdentityOptimizer", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay.ext.ethos-u.IdentityOptimizer").set_body_typed(IdentityOptimizer);
Expand Down
47 changes: 42 additions & 5 deletions tests/python/contrib/test_ethosu/test_identity_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,14 @@ def test_many_output_identity():
def get_graph(get_expected=False):
x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
x = relay.reshape(x, newshape=(1, 1, 4, 4))
identity = infra.make_ethosu_identity(x)
if not get_expected:
x = infra.make_ethosu_identity(x)
outputs = []
for _ in range(4):
ifm = x if get_expected else identity
outputs.append(infra.make_ethosu_unary_elementwise(ifm, 4, "ABS"))
outputs.append(relay.strided_slice(identity, begin=(0, 0, 0, 0), end=(1, 1, 4, 4)))
outputs.append(infra.make_ethosu_unary_elementwise(x, 4, "ABS"))
ss = relay.strided_slice(x, begin=(0, 0, 0, 0), end=(1, 1, 4, 4))
identity_2 = infra.make_ethosu_identity(ss)
outputs.append(identity_2)
out = relay.concatenate(outputs, axis=0)
return relay.Function(relay.analysis.free_vars(out), out)

Expand Down Expand Up @@ -220,7 +222,8 @@ def test_identity_removal_with_multiple_transform_ops():
def get_graph(get_expected=False):
x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
x = relay.strided_slice(x, begin=[0, 0, 0, 0], end=[1, 2, 2, 2])
x = infra.make_ethosu_identity(x)
if not get_expected:
x = infra.make_ethosu_identity(x)
x = relay.reshape(x, newshape=(1, 1, 1, 8))
if not get_expected:
x = infra.make_ethosu_identity(x)
Expand Down Expand Up @@ -267,6 +270,25 @@ def get_graph(get_expected=False):
_assert_structural_equal(actual, expected)


def test_multiple_transform_ops_with_reduction_in_dimensionality():
    """Removing an identity sandwiched between two transform operations is
    usually fine, but when the second transform reduces the dimensionality
    of the first transform's output, removal can corrupt the result. Verify
    that the pass leaves such an identity in place."""

    def build_graph():
        inp = relay.var("x", shape=(1, 2, 2, 4), dtype="int8")
        sliced = relay.strided_slice(inp, begin=(0, 0, 0, 0), end=(1, 2, 2, 2))
        ident_1 = infra.make_ethosu_identity(sliced)
        # reshape drops the rank from 4 to 3, so the identity must survive
        reshaped = relay.reshape(ident_1, newshape=(1, 2, 4))
        ident_2 = infra.make_ethosu_identity(reshaped)
        return relay.Function(relay.analysis.free_vars(ident_2), ident_2)

    actual = _optimize(build_graph())
    expected = _optimize(build_graph(), optimize=False)
    _assert_structural_equal(actual, expected)


def test_identity_optimizer_runs_in_compilation_pipeline():
"""Checks that the identity optimization pass is run as part of the NPU compilation pipeline."""

Expand Down Expand Up @@ -320,3 +342,18 @@ def model(x):
return y

_compare_tvm_with_tflite(model, [ifm_shape], "ethos-u55-256")


def test_multiple_transform_ops_same_output():
    """Compile a model with back-to-back transform ops and compare the NPU
    result against TFLite, ensuring identity handling does not alter the
    output."""
    ifm_shape = (1, 2, 2, 4)

    @tf.function
    def model(x):
        reshaped = tf.reshape(x, (1, 1, 4, 4))
        sliced = tf.slice(reshaped, (0, 0, 0, 0), (1, 1, 4, 3))
        return tf.reshape(sliced, (12,))

    _compare_tvm_with_tflite(model, [ifm_shape], "ethos-u55-256")