From bb67e2995a215512e03285a28058eb6096b0c937 Mon Sep 17 00:00:00 2001 From: Ryan Date: Wed, 9 Nov 2022 21:22:24 -0600 Subject: [PATCH] [Relay] Refactor constant folding over expr into a utility function --- .../backend/contrib/constant_transforms.cc | 10 +--- .../backend/contrib/constant_transforms.h | 9 --- .../contrib/ethosn/convert_equivalent.cc | 4 +- src/relay/quantize/realize.cc | 18 ++---- src/relay/quantize/realize.h | 2 - src/relay/transforms/fold_constant.cc | 20 +++---- src/relay/transforms/fold_constant.h | 55 +++++++++++++++++++ src/relay/transforms/simplify_expr.cc | 11 +--- 8 files changed, 78 insertions(+), 51 deletions(-) create mode 100644 src/relay/transforms/fold_constant.h diff --git a/src/relay/backend/contrib/constant_transforms.cc b/src/relay/backend/contrib/constant_transforms.cc index 6041d37451aa..45669b5ef271 100644 --- a/src/relay/backend/contrib/constant_transforms.cc +++ b/src/relay/backend/contrib/constant_transforms.cc @@ -21,6 +21,7 @@ #include +#include "../../transforms/fold_constant.h" #include "../../transforms/pattern_utils.h" #include "../../transforms/simplify_expr.h" @@ -33,13 +34,6 @@ namespace tvm { namespace relay { namespace contrib { -Expr FoldConstantExpr(const Expr& expr, bool fold_qnn) { - auto mod = IRModule::FromExpr(expr); - mod = transform::FoldConstant(fold_qnn)(mod); - auto entry_func = Downcast(mod->Lookup("main")); - return expr.as() == nullptr ? entry_func->body : entry_func; -} - Constant TransposeWeights(const Constant& data, const std::string& source_layout, const std::string& target_layout) { Array transpose_matrix; @@ -48,7 +42,7 @@ Constant TransposeWeights(const Constant& data, const std::string& source_layout transpose_matrix.push_back(pos); } Expr transpose = MakeTranspose(data, transpose_matrix); - transpose = InferType(FoldConstantExpr(transpose)); + transpose = InferType(transform::FoldConstantExpr(transpose)); Constant transposed_data = Downcast(transpose); return transposed_data; } diff --git a/src/relay/backend/contrib/constant_transforms.h b/src/relay/backend/contrib/constant_transforms.h index 39a9dc1d53d4..f642564115b6 100644 --- a/src/relay/backend/contrib/constant_transforms.h +++ b/src/relay/backend/contrib/constant_transforms.h @@ -33,15 +33,6 @@ namespace tvm { namespace relay { namespace contrib { -/*! - * \brief Apply constant folding on an expression. - * - * \param expr The expression to fold. - * \param fold_qnn Whether to fold constants for QNN operations. - * \returns The new folded expression. - */ -Expr FoldConstantExpr(const Expr& expr, bool fold_qnn = true); - /*! *\brief Transpose weights from `source_layout` to `target_layout` * diff --git a/src/relay/backend/contrib/ethosn/convert_equivalent.cc b/src/relay/backend/contrib/ethosn/convert_equivalent.cc index 14d94192c84e..ef8c4a5ef567 100644 --- a/src/relay/backend/contrib/ethosn/convert_equivalent.cc +++ b/src/relay/backend/contrib/ethosn/convert_equivalent.cc @@ -30,9 +30,9 @@ #include #include "../../../qnn/utils.h" +#include "../../../transforms/fold_constant.h" #include "../../../transforms/pattern_utils.h" #include "../../../transforms/simplify_expr.h" -#include "../constant_transforms.h" #include "ethosn_api.h" namespace tvm { @@ -176,7 +176,7 @@ Optional ConvertQnnAddToDepthwise(const Expr& expr) { Expr reshape_bias = MakeReshape(requantize_bias, {channels}); try { - reshape_bias = FoldConstantExpr(reshape_bias); + reshape_bias = transform::FoldConstantExpr(reshape_bias); } catch (tvm::Error& e) { // Conversion produced an invalid op. return NullOpt; diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc index 720ef25cd33d..3c2f6eb96d6b 100644 --- a/src/relay/quantize/realize.cc +++ b/src/relay/quantize/realize.cc @@ -33,6 +33,7 @@ #include "../op/annotation/annotation.h" #include "../qnn/utils.h" +#include "../transforms/fold_constant.h" #include "./quantize.h" namespace tvm { @@ -154,13 +155,6 @@ Expr QuantizeRealize(const Call& ref_call, const Array& new_args, const Ob return QRealizeIntExpr(round_data, dom_scale, DataType::Float(32)); } -Expr FoldConstantOpt(const Expr& expr) { - auto mod = IRModule::FromExpr(expr); - mod = transform::FoldConstant()(mod); - auto entry_func = Downcast(mod->Lookup("main")); - return expr.as() == nullptr ? entry_func->body : entry_func; -} - RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize") .set_attr("FQRealizeRewrite", QuantizeRealize); @@ -184,7 +178,7 @@ Expr Conv2dRealize(const Call& ref_call, const Array& new_args, const Obje Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args); Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); - Expr dom_scale = FoldConstantOpt(mul); + Expr dom_scale = FoldConstantExpr(mul); return QRealizeIntExpr(ret, dom_scale, out_dtype); } ICHECK(!new_args[0]->IsInstance() || !new_args[1]->IsInstance()); @@ -218,7 +212,7 @@ Expr Conv1dRealize(const Call& ref_call, const Array& new_args, const Obje Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args); Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); - Expr dom_scale = FoldConstantOpt(mul); + Expr dom_scale = FoldConstantExpr(mul); return QRealizeIntExpr(ret, dom_scale, out_dtype); } @@ -247,7 +241,7 @@ Expr DenseRealize(const Call& ref_call, const Array& new_args, const Objec Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args); Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); - Expr dom_scale = FoldConstantOpt(mul); + Expr dom_scale = FoldConstantExpr(mul); return QRealizeIntExpr(ret, dom_scale, out_dtype); } @@ -273,7 +267,7 @@ Expr MulRealize(const Call& ref_call, const Array& new_args, const ObjectR Expr ret = ForwardOp(ref_call, {ldata, rdata}); Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); - Expr dom_scale = FoldConstantOpt(mul); + Expr dom_scale = FoldConstantExpr(mul); return QRealizeIntExpr(ret, dom_scale, dtype); } ICHECK(!new_args[0]->IsInstance() || !new_args[1]->IsInstance()); @@ -527,7 +521,7 @@ Expr BatchMatmulRealize(const Call& ref_call, const Array& new_args, const Expr ret = Call(ref_call->op, {ldata, rdata}, Attrs(attrs), ref_call->type_args); Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale); - Expr dom_scale = FoldConstantOpt(mul); + Expr dom_scale = FoldConstantExpr(mul); return QRealizeIntExpr(ret, dom_scale, out_dtype); } diff --git a/src/relay/quantize/realize.h b/src/relay/quantize/realize.h index 16fdf79b246e..6eba69e9c9b1 100644 --- a/src/relay/quantize/realize.h +++ b/src/relay/quantize/realize.h @@ -69,8 +69,6 @@ class QRealizeIntExpr : public QRealizeExpr { TVM_DEFINE_OBJECT_REF_METHODS(QRealizeIntExpr, QRealizeExpr, QRealizeIntExprNode); }; -Expr FoldConstantOpt(const Expr& expr); - } // namespace quantize } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 9dec840be0a7..aee402836f89 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -418,14 +418,6 @@ class ConstantFolder : public MixedModeMutator { TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(IsComplexConstant); -/*! - * \brief Returns \p expr with any constants expressions evaluated and let-bound constants - * inlined. Returns \p expr unchanged if no change. - * - * CAUTION: The importers rely on this function returning \p expr unchanged to preserve sharing - * from their p.o.v. Furthermore, this function can be called before conversion to ANF so - * we must avoid all recursion. - */ Expr FoldConstantExpr(const Expr& expr, const IRModule& mod, bool fold_qnn) { VLOG_CONTEXT << "FoldConstantExpr"; VLOG(1) << "folding:" << std::endl << PrettyPrint(expr); @@ -434,11 +426,19 @@ Expr FoldConstantExpr(const Expr& expr, const IRModule& mod, bool fold_qnn) { return result; } -TVM_REGISTER_GLOBAL("relay._transform.FoldConstantExpr").set_body_typed(FoldConstantExpr); +Expr FoldConstantExpr(const Expr& expr, bool fold_qnn) { + auto mod = IRModule::FromExpr(expr); + return FoldConstantExpr(expr, mod, fold_qnn); +} + +TVM_REGISTER_GLOBAL("relay._transform.FoldConstantExpr") + .set_body_typed([](const Expr& expr, const IRModule& mod, bool fold_qnn) { + return FoldConstantExpr(expr, mod, fold_qnn); + }); Pass FoldConstant(bool fold_qnn) { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { + [=](Function f, IRModule m, PassContext /* pc */) { return Downcast(FoldConstantExpr(f, m, fold_qnn)); }; return CreateFunctionPass(pass_func, 2, "FoldConstant", {}); diff --git a/src/relay/transforms/fold_constant.h b/src/relay/transforms/fold_constant.h new file mode 100644 index 000000000000..4f475037d195 --- /dev/null +++ b/src/relay/transforms/fold_constant.h @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file fold_constant.h + * \brief Utility functions for folding constants in expressions. + */ +#ifndef TVM_RELAY_TRANSFORMS_FOLD_CONSTANT_H_ +#define TVM_RELAY_TRANSFORMS_FOLD_CONSTANT_H_ + +#include + +namespace tvm { +namespace relay { +namespace transform { + +/*! + * \brief Apply constant folding on an expression. + * + * \param expr The expression to fold. + * \param fold_qnn Whether to fold constants for QNN operations. + * \returns The new folded expression. + */ +Expr FoldConstantExpr(const Expr& expr, bool fold_qnn = true); + +/*! + * \brief Returns \p expr with any constants expressions evaluated and let-bound constants + * inlined. Returns \p expr unchanged if no change. + * + * CAUTION: The importers rely on this function returning \p expr unchanged to preserve sharing + * from their p.o.v. Furthermore, this function can be called before conversion to ANF so + * we must avoid all recursion. + */ +Expr FoldConstantExpr(const Expr& expr, const IRModule& mod, bool fold_qnn); + +} // namespace transform +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_TRANSFORMS_FOLD_CONSTANT_H_ diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index 923a18f7bc93..c64957b5b62a 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -37,6 +37,7 @@ #include #include "../op/tensor/transform.h" +#include "fold_constant.h" #include "pattern_utils.h" namespace tvm { @@ -795,10 +796,7 @@ class SwitchAddMultiply : public DFPatternRewrite { } Expr const_expr = Call(Op::Get("multiply"), {c1, c2}); - IRModule const_mod = IRModule::FromExpr(const_expr); - const_mod = transform::FoldConstant()(const_mod); - GlobalVar const_main = const_mod->GetGlobalVar("main"); - Expr const_val = Downcast(const_mod->functions[const_main])->body; + Expr const_val = transform::FoldConstantExpr(const_expr); return Call(Op::Get("add"), {Call(Op::Get("multiply"), {x, c2}), const_val}); } @@ -833,10 +831,7 @@ class SimplifyAdjacentMultiplyOrAdd : public DFPatternRewrite { } Expr const_expr = Call(call->op, {c1, c2}); - IRModule const_mod = IRModule::FromExpr(const_expr); - const_mod = transform::FoldConstant()(const_mod); - GlobalVar const_main = const_mod->GetGlobalVar("main"); - Expr const_val = Downcast(const_mod->functions[const_main])->body; + Expr const_val = transform::FoldConstantExpr(const_expr); return Call(call->op, {x, const_val}); }