Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tvm/relay/op/contrib/_ethosn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
import tvm._ffi

tvm._ffi._init_api("relay.ethos-n.support", __name__)
tvm._ffi._init_api("relay.backend.contrib.ethos-n", __name__)
80 changes: 62 additions & 18 deletions python/tvm/relay/op/contrib/ethosn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from tvm.relay.build_module import bind_params_by_name

from ...dataflow_pattern import is_constant, is_op, wildcard
from . import _ethosn as support
from . import _ethosn
from .register import register_pattern_table


Expand Down Expand Up @@ -60,6 +60,18 @@ def ethosn_api_version() -> str:
return tvm.get_global_func("relay.ethos-n.api.version")()


def ConvertEquivalents() -> tvm.ir.IRModule: # pylint: disable=invalid-name
"""Converts operations into a numerically equivalent form
that can be understood by the NPU codegen.

Return
------
Pass
The module pass.
"""
return _ethosn.ConvertEquivalents()


def partition_for_ethosn(mod, params=None, **opts):
"""Partition the graph greedily offloading supported
operators to Arm Ethos-N NPU.
Expand Down Expand Up @@ -107,9 +119,9 @@ def partition_for_ethosn(mod, params=None, **opts):
transform.AnnotateTarget("ethos-n"),
transform.MergeCompilerRegions(),
transform.PartitionGraph(),
ConvertEquivalents(),
]
)

return seq(mod)


Expand Down Expand Up @@ -183,70 +195,102 @@ def qnn_resize_pattern():
)
return pattern

def qnn_mul_pattern():
"""
Multiply is supported when one input is a constant of shape [1, ..., C],
where C matches the number of channels of the other input.
"""
mul_op = is_op("qnn.mul")
gen_mul_inputs = lambda x, y: mul_op(
x,
y,
is_constant(),
is_constant(),
is_constant(),
is_constant(),
is_constant(),
is_constant(),
)
input_is_left = gen_mul_inputs(wildcard(), is_constant())
input_is_right = gen_mul_inputs(is_constant(), wildcard())
return input_is_left | input_is_right

def check_conv2d(extract):
"""Check if a conv2d is supported by Ethos-N."""
if not ethosn_available():
return False

return support.conv2d(extract)
return _ethosn.conv2d(extract)

def check_fc(extract):
"""Check if a fully connected is supported by Ethos-N."""
if not ethosn_available():
return False

return support.fc(extract)
return _ethosn.fc(extract)

def check_avg_pool2d(extract):
"""Check if a avg pool2d is supported by Ethos-N."""
if not ethosn_available():
return False

return support.avg_pool2d(extract)
return _ethosn.avg_pool2d(extract)

def check_mean(extract):
"""Check if mean is supported by Ethos-N."""
if not ethosn_available():
return False

return support.mean(extract)
return _ethosn.mean(extract)

def check_sigmoid(extract):
"""Check if a sigmoid is supported by Ethos-N."""
if not ethosn_available():
return False

return support.sigmoid(extract)
return _ethosn.sigmoid(extract)

def check_tanh(extract):
"""Check if tanh is supported by Ethos-N."""
if not ethosn_available():
return False

return support.tanh(extract)
return _ethosn.tanh(extract)

def check_leaky_relu(extract):
"""Check if Leaky ReLU is supported."""
if not ethosn_available():
return False

return support.leaky_relu(extract)
return _ethosn.leaky_relu(extract)

def check_mul(extract):
"""Check if Mul is supported."""
if not ethosn_available():
return False
# Do not support scalar constants for now
check_scalar = lambda i: isinstance(i, tvm.relay.Constant) and len(i.data.shape) == 0
if check_scalar(extract.args[0]) or check_scalar(extract.args[1]):
return False
extract = _ethosn.ConvertQnnMultiply(extract)
return _ethosn.conv2d(extract)

def check_requantize(extract):
"""Check if requantize is supported."""
if not ethosn_available():
return False

return support.requantize(extract)
return _ethosn.requantize(extract)

def check_resize(extract):
"""Check if resize (nearest neighbor) is supported."""
if not ethosn_available():
return False

return support.resize(extract)
return _ethosn.resize(extract)

return [
("ethos-n.qnn_mul", qnn_mul_pattern(), check_mul),
("ethos-n.qnn_conv2d", qnn_conv_pattern(), check_conv2d),
("ethos-n.qnn_avg_pool2d", qnn_avg_pool2d_pattern(), check_avg_pool2d),
("ethos-n.qnn_sigmoid", qnn_sigmoid_pattern(), check_sigmoid),
Expand Down Expand Up @@ -274,7 +318,7 @@ def max_pool2d(expr):
if not ethosn_available():
return False

return support.max_pool2d(expr)
return _ethosn.max_pool2d(expr)


@tvm.ir.register_op_attr("reshape", "target.ethos-n")
Expand All @@ -285,7 +329,7 @@ def reshape(expr):
if not _is_ethosn_composite(expr.args[0]):
return False

return support.reshape(expr)
return _ethosn.reshape(expr)


@tvm.ir.register_op_attr("qnn.add", "target.ethos-n")
Expand All @@ -294,15 +338,15 @@ def qnn_add(expr):
if not ethosn_available():
return False

return support.addition(expr)
return _ethosn.addition(expr)


@tvm.ir.register_op_attr("qnn.concatenate", "target.ethos-n")
def qnn_concatenate(expr):
"""Check if a concatenate is supported by Ethos-N."""
if not ethosn_available():
return False
if not support.concatenate(expr):
if not _ethosn.concatenate(expr):
return False

# Support library has some unenforced restrictions on qnn params
Expand Down Expand Up @@ -332,7 +376,7 @@ def split(expr):
return False
if ethosn_api_version() >= LooseVersion("3.0.1"):
return False
if not support.split(expr):
if not _ethosn.split(expr):
return False

return True
Expand All @@ -343,7 +387,7 @@ def depth_to_space(expr):
"""Check if a depth_to_space is supported by Ethos-N."""
if not ethosn_available():
return False
if not support.depth_to_space(expr):
if not _ethosn.depth_to_space(expr):
return False

return True
Expand All @@ -354,7 +398,7 @@ def clip(expr):
"""Check if a clip is supported by Ethos-N."""
if not ethosn_available():
return False
if not support.relu(expr):
if not _ethosn.relu(expr):
return False

return True
144 changes: 144 additions & 0 deletions src/relay/backend/contrib/ethosn/convert_equivalent.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/backend/contrib/ethosn/convert_equivalent.cc
* \brief Converts operations into a numerically equivalent form
* that can be understood by the NPU codegen.
*/

#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>

#include <unordered_map>

#include "../../../qnn/utils.h"
#include "../../../transforms/pattern_utils.h"
#include "../../../transforms/simplify_expr.h"

namespace tvm {
namespace relay {
namespace contrib {
namespace ethosn {

/*!
* \brief Converts qnn.mul to mathematically equivalent
* qnn.conv2d depthwise operation.
*/
Expr ConvertQnnMultiply(const Expr& expr) {
Call call = Downcast<Call>(expr);

Expr input1 = call->args[0];
Expr input2 = call->args[1];
Expr input1_scale = call->args[2];
Expr input1_zero_point = call->args[3];
Expr input2_scale = call->args[4];
Expr input2_zero_point = call->args[5];
// Reverse the inputs if the constant is first input
if (call->args[0]->IsInstance<ConstantNode>()) {
input1 = call->args[1];
input2 = call->args[0];
input1_scale = call->args[4];
input1_zero_point = call->args[5];
input2_scale = call->args[2];
input2_zero_point = call->args[3];
}
Expr output_scale = call->args[6];
Expr output_zero_point = call->args[7];

const auto* input_constant = input2.as<ConstantNode>();
ICHECK(input_constant) << "Expected ConstantNode but got " << input2->GetTypeKey();
const auto* input_constant_tt = input_constant->checked_type().as<TensorTypeNode>();
int channels = input_constant_tt->shape.back().as<IntImmNode>()->value;

runtime::NDArray input_data = input_constant->data;
runtime::NDArray kernel_data_hwoi =
runtime::NDArray::Empty({1, 1, channels, 1}, input_data->dtype, input_data->device);
kernel_data_hwoi.CopyFrom(input_data);
Constant kernel = Constant(kernel_data_hwoi, input_constant->span);

Type output_type = expr->checked_type();
auto output_tt = output_type.as<TensorTypeNode>();
ICHECK(output_tt) << "Expected TensorTypeNode but got " << output_type->GetTypeKey();
DataType output_dtype = output_tt->dtype;

Expr conv2d = qnn::MakeQnnConv2D(
input1, kernel, input1_zero_point, input2_zero_point, input1_scale, input2_scale, {1, 1},
{0, 0, 0, 0}, {1, 1}, channels, channels, {1, 1}, "NHWC", "HWOI", "NHWC", DataType::Int(32));
Constant bias_data = MakeConstantZeros(DataType::Int(32), {channels});
Expr bias_add = MakeBiasAdd(conv2d, bias_data, 3);
Expr requantize = qnn::MakeRequantize(bias_add, input1_scale, input1_zero_point, output_scale,
output_zero_point, -1, "None", "None", output_dtype);

return InferType(requantize);
}

TVM_REGISTER_GLOBAL("relay.backend.contrib.ethos-n.ConvertQnnMultiply")
.set_body_typed(ConvertQnnMultiply);

class ConvertEquivalentsMutator : public MixedModeMutator {
public:
Expr Rewrite_(const CallNode* pre, const Expr& post) override {
Call call = Downcast<Call>(post);
if (!call->op->IsInstance<FunctionNode>()) {
return post;
}

Function func = Downcast<Function>(call->op);
Function new_func = Function(func);
auto composite_name = func->GetAttr<String>(attr::kComposite);
if (composite_name == "ethos-n.qnn_mul") {
Expr new_func_body = ConvertQnnMultiply(func->body);
new_func = WithFields(func, func->params, new_func_body);
new_func = WithAttr(std::move(new_func), attr::kComposite, String("ethos-n.qnn_conv2d"));
}

Call new_call = WithFields(call, new_func);
return Downcast<Expr>(new_call);
}
};

tvm::transform::Pass ConvertEquivalents() {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
[=](IRModule mod, transform::PassContext ctx) {
for (auto gv : mod->GetGlobalVars()) {
Function func = Downcast<Function>(mod->Lookup(gv));
auto compiler_name = func->GetAttr<String>(attr::kCompiler);
if (compiler_name.defined() && compiler_name == "ethos-n") {
auto new_body = ConvertEquivalentsMutator().VisitExpr(func->body);
if (!new_body.same_as(func->body)) {
Function new_func = WithFields(func, func->params, new_body);
mod->Update(gv, new_func);
}
}
}
return mod;
};
return tvm::transform::CreateModulePass(
pass_func, 0, "relay.backend.contrib.ethos-n.ConvertEquivalents", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay.backend.contrib.ethos-n.ConvertEquivalents")
.set_body_typed(ConvertEquivalents);

} // namespace ethosn
} // namespace contrib
} // namespace relay
} // namespace tvm
2 changes: 2 additions & 0 deletions src/relay/op/make_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ Expr MakeShapeOf(Expr data, DataType dtype);

Expr MakeTake(Expr data, Expr indices, Integer batch_dims, Integer axis, String mode);

Expr MakeBiasAdd(Expr data, Expr bias, int axis);

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_MAKE_OP_H_
4 changes: 4 additions & 0 deletions src/relay/qnn/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ static inline Expr Requantize(const Expr& data, const Array<IndexExpr>& input_sh
attrs.operator->(), input_shape, attrs->out_dtype);
}

Expr MakeRequantize(Expr data, Expr input_scale, Expr input_zero_point, Expr output_scale,
Expr output_zero_point, int axis, String rounding, String compute_dtype,
DataType out_dtype);

Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
const Expr& input_zero_point, const Array<tvm::relay::Type>& types,
const DequantizeAttrs* attrs);
Expand Down
Loading