Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions include/tvm/relay/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@ class Op : public relay::Expr {
*/
template <typename ValueType>
inline static OpMap<ValueType> GetAttr(const std::string& attr_name);
/*!
* \brief Checks if an attr is present in the registry.
* \param attr_name The name of the attribute.
* \return bool True if the attr is present.
*/
inline static bool HasAttr(const std::string& attr_name);
/*!
* \brief Get an Op for a given operator name.
* Will raise an error if the op has not been registered.
Expand All @@ -171,6 +177,12 @@ class Op : public relay::Expr {
* \return reference to GenericOpMap
*/
TVM_DLL static const GenericOpMap& GetGenericAttr(const std::string& key);
/*!
* \brief Checks if the key is present in the registry
* \param key The attribute key
* \return bool True if the key is present
*/
TVM_DLL static const bool HasGenericAttr(const std::string& key);
};

/*! \brief Helper structure to register operators */
Expand Down Expand Up @@ -393,6 +405,10 @@ inline OpMap<ValueType> Op::GetAttr(const std::string& key) {
return OpMap<ValueType>(Op::GetGenericAttr(key));
}

inline bool Op::HasAttr(const std::string& key) {
return Op::HasGenericAttr(key);
}

inline OpNode* OpRegistry::get() {
return const_cast<OpNode*>(op_.operator->());
}
Expand Down
60 changes: 60 additions & 0 deletions include/tvm/relay/qnn/transform.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relay/qnn/transform.h
*
* This file implements a pass manager for QNN ops using Relay Pass manager.
*/
#ifndef TVM_RELAY_QNN_TRANSFORM_H_
#define TVM_RELAY_QNN_TRANSFORM_H_

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/relay/transform.h>

namespace tvm {
namespace relay {

using relay::transform::Pass;

namespace qnn {
namespace transform {

/*!
* \brief Legalizes a QNN expr. Contains specifically two types of Legalizations. First,
* converts/Lowers an expression containing QNN ops to an expression containing only core Relay ops.
* Each QNN op is lowered to a sequence of exisiting Relay ops. This is a target-independent pass.
* One can register the lowering/transformation function for this op using FTVMQnnCanonicalize
* attr_name for FTVMLegalize op attribute. Second, as opposed to Relay Legalize, this one legalizes
* only QNN ops. One can register a transformation/legalization function for an op by using the
* FTVMQnnLegalize attr_name for FTVMLegalize op attribute. The isolation of QNN and Relay Legalize
* gives us separation of concerns, leading to a better software practice. The legalization can be
* configured to happen per target.
*
* \return The pass.
*/
TVM_DLL Pass Legalize();

} // namespace transform

} // namespace qnn
} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_QNN_TRANSFORM_H_
15 changes: 10 additions & 5 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/runtime/vm.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/qnn/transform.h>
#include <memory>

#include "utils.h"
Expand Down Expand Up @@ -282,6 +283,15 @@ class RelayBuildModule : public runtime::ModuleNode {
const TargetsMap& targets,
const std::unordered_map<std::string, runtime::NDArray>& params) {
Array<Pass> pass_seqs;

// Run all dialect legalization passes.
pass_seqs.push_back(relay::qnn::transform::Legalize());

// Legalize pass is restricted to homogeneous execution for now.
if (targets.size() == 1) {
pass_seqs.push_back(transform::Legalize());
}

pass_seqs.push_back(transform::SimplifyInference());
PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
Expr expr = args[0];
Expand All @@ -304,11 +314,6 @@ class RelayBuildModule : public runtime::ModuleNode {
pass_seqs.push_back(transform::CanonicalizeCast());
pass_seqs.push_back(transform::CanonicalizeOps());

// Legalize pass is restricted to homogeneous execution for now.
if (targets.size() == 1) {
pass_seqs.push_back(transform::Legalize());
}

// Alter layout transformation is only applied to homogeneous execution yet.
if (targets.size() == 1) {
pass_seqs.push_back(transform::AlterOpLayout());
Expand Down
11 changes: 11 additions & 0 deletions src/relay/ir/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,17 @@ const GenericOpMap& Op::GetGenericAttr(const std::string& key) {
return *it->second.get();
}

// Check if a key is present in the registry.
const bool Op::HasGenericAttr(const std::string& key) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex> lock(mgr->mutex);
auto it = mgr->attr.find(key);
if (it == mgr->attr.end()) {
return false;
}
return true;
}

void OpRegistry::UpdateAttr(const std::string& key,
TVMRetValue value,
int plevel) {
Expand Down
58 changes: 33 additions & 25 deletions src/relay/pass/legalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,32 +46,40 @@ class Legalizer : public ExprMutator {
Expr new_e = ExprMutator::VisitExpr_(call_node);
Call new_call = Downcast<Call>(new_e);

// Check if the string is registered in the OpRegistry.
if (!Op::HasAttr(legalize_map_attr_name_)) {
return new_e;
}

// Collect the registered legalize function.
auto fop_legalize = Op::GetAttr<FTVMLegalize>(legalize_map_attr_name_);
Op op = Downcast<Op>(call_node->op);

if (fop_legalize.count(op)) {
// Collect the new_args.
tvm::Array<Expr> call_args = new_call->args;

// Collect input and output dtypes to pass on to Legalize API.
tvm::Array<tvm::relay::Type> types;
for (auto arg : call_node->args) {
types.push_back(arg->checked_type());
}
types.push_back(call_node->checked_type());

// Transform the op by calling the registered legalize function.
Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);

// Reassign new_e if the transformation succeeded.
if (legalized_value.defined()) {
// Check that the returned Expr from legalize is CallNode.
const CallNode* legalized_call_node = legalized_value.as<CallNode>();
CHECK(legalized_call_node)
<< "Can only replace the original operator with another call node";

new_e = legalized_value;
auto call_op = call_node->op;
if (call_op.as<OpNode>()) {
Op op = Downcast<Op>(call_node->op);

if (fop_legalize.count(op)) {
// Collect the new_args.
tvm::Array<Expr> call_args = new_call->args;

// Collect input and output dtypes to pass on to Legalize API.
tvm::Array<tvm::relay::Type> types;
for (auto arg : call_node->args) {
types.push_back(arg->checked_type());
}
types.push_back(call_node->checked_type());

// Transform the op by calling the registered legalize function.
Expr legalized_value = fop_legalize[op](call_node->attrs, call_args, types);

// Reassign new_e if the transformation succeeded.
if (legalized_value.defined()) {
// Check that the returned Expr from legalize is CallNode.
const CallNode* legalized_call_node = legalized_value.as<CallNode>();
CHECK(legalized_call_node)
<< "Can only replace the original operator with another call node";

new_e = legalized_value;
}
}
}

Expand All @@ -95,7 +103,7 @@ Pass Legalize(const std::string& legalize_map_attr_name) {
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(relay::legalize::Legalize(f, legalize_map_attr_name));
};
return CreateFunctionPass(pass_func, 3, "Legalize", {ir::StringImm::make("InferType")});
return CreateFunctionPass(pass_func, 0, "Legalize", {ir::StringImm::make("InferType")});
}

TVM_REGISTER_API("relay._transform.Legalize").set_body_typed(Legalize);
Expand Down
47 changes: 47 additions & 0 deletions src/relay/qnn/pass/legalize.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file relay/qnn/pass/legalize.cc
* \brief The Legalize wrapper for QNN.
*/

#include <tvm/relay/qnn/transform.h>

namespace tvm {
namespace relay {
namespace qnn {

namespace transform {

Pass Legalize() {
Array<Pass> pass_seqs;
pass_seqs.push_back(relay::transform::Legalize("FTVMQnnLegalize"));
pass_seqs.push_back(relay::transform::Legalize("FTVMQnnCanonicalize"));
relay::transform::Pass seq = relay::transform::Sequential(pass_seqs);
return seq;
}

TVM_REGISTER_API("relay.qnn._transform.Legalize").set_body_typed(Legalize);

} // namespace transform

} // namespace qnn
} // namespace relay
} // namespace tvm
1 change: 0 additions & 1 deletion tests/python/relay/test_op_qnn_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def get_qnn_func(data,

mod = relay.Function(relay.analysis.free_vars(func), func)
mod = relay.Module.from_expr(mod)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
return mod

def get_funcs(data_shape,
Expand Down
1 change: 0 additions & 1 deletion tests/python/relay/test_op_qnn_dequantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def quantize_test_driver(in_dtype, quant_args, in_data, verify_output_data):
input_zero_point=input_zero_point)
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
mod = relay.Module.from_expr(mod)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm", params=None)
rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
Expand Down
1 change: 0 additions & 1 deletion tests/python/relay/test_op_qnn_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def quantize_test_driver(in_dtype, quant_args, out_dtype, in_data, verify_output
output_zero_point=output_zero_point,out_dtype=out_dtype)
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
mod = relay.Module.from_expr(mod)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, "llvm", params=None)
rt_mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
Expand Down
1 change: 0 additions & 1 deletion tests/python/relay/test_op_qnn_requantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def get_mod(data_shape, data_dtype, out_dtype, input_scale, output_scale,

mod = relay.Function(relay.analysis.free_vars(mod), mod)
mod = relay.Module.from_expr(mod)
mod = relay.qnn.transform.CanonicalizeOps()(mod)
return mod

def same_scale_test():
Expand Down