From 930f93614a14c367f752eb6425414dca30a5aba6 Mon Sep 17 00:00:00 2001
From: Siyuan Feng
Date: Fri, 24 Mar 2023 21:41:26 +0800
Subject: [PATCH] [Unity] Support simple dynamic-shape-aware fusion

This PR adds support for simple dynamic-shape-aware fusion, which is the
first step towards supporting dynamic shapes. The main changes are as
follows:

- Fix FuncStructInfo in well-formed checks
- Renew symbolic var defs in fuse_ops to prevent malformed functions
---
 src/relax/analysis/well_formed.cc             | 12 +++++
 src/relax/ir/expr_functor.cc                  |  5 +-
 src/relax/transform/fuse_ops.cc               | 48 +++++++++++++++++--
 .../relax_vm/cuda/cuda_graph_builtin.cc       |  4 +-
 .../python/relax/test_analysis_well_formed.py | 13 +++++
 tests/python/relax/test_transform_fuse_ops.py | 37 ++++++++++++++
 6 files changed, 112 insertions(+), 7 deletions(-)

diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc
index 9a97931136c8..3eeefd0be584 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -402,6 +402,18 @@ class WellFormedChecker : public relax::ExprVisitor,
     }
   }
 
+  void VisitStructInfo_(const FuncStructInfoNode* op) final {
+    if (op->params.defined()) {
+      WithMode(VisitMode::kMatchVarDef, [&]() {
+        ICHECK(mode_ == VisitMode::kMatchVarDef);
+        for (StructInfo param : op->params.value()) {
+          this->VisitStructInfo(param);
+        }
+      });
+    }
+    this->VisitStructInfo(op->ret);
+  }
+
   void VisitStructInfoExprField(const Expr& expr) final {
     if (mode_ == VisitMode::kMatchVarDef) {
       // populate symbolic var in first occurrence
diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc
index 174d40053fa0..3f0fc86a2a37 100644
--- a/src/relax/ir/expr_functor.cc
+++ b/src/relax/ir/expr_functor.cc
@@ -577,7 +577,10 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
   for (Var param : op->params) {
     Var new_param = this->VisitVarDef(param);
     params.push_back(new_param);
-    all_params_unchanged &= param.same_as(new_param);
+    if (!param.same_as(new_param)) {
+      var_remap_[param->vid] = new_param;
+      all_params_unchanged = false;
+    }
   }
 
   Expr body = this->VisitWithNewScope(op->body, params);
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 24f068c03f29..8e4346e2062b 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -33,6 +33,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
@@ -344,6 +345,45 @@ class GraphCreator : public ExprVisitor {
   std::unordered_set initialized_nodes_;
 };
 
+/*!
+ * \brief Renew the definition of symbolic vars in Relax.
+ * \details This mutator is used to prevent the same symbolic var from being used in different
+ * functions, which is malformed.
+ */
+class SymbolicVarRenewMutator : public ExprMutator, tir::ExprMutator {
+ public:
+  static Function Renew(const Function& function) {
+    SymbolicVarRenewMutator mutator;
+    return Downcast<Function>(mutator.VisitExpr(function));
+  }
+
+ private:
+  SymbolicVarRenewMutator() = default;
+  using relax::ExprMutator::VisitExpr;
+  using relax::ExprMutator::VisitExpr_;
+  using tir::ExprMutator::VisitExpr_;
+
+  PrimExpr VisitPrimExpr(const PrimExpr& expr) final { return tir::ExprMutator::VisitExpr(expr); }
+
+  // TODO(Siyuan): enhance the method to the following steps:
+  // 1. Visit and replace all tir::Vars at the definition point
+  // 2. Revisit the function again and update the use side.
+  PrimExpr VisitExpr_(const tir::VarNode* op) final {
+    auto it = var_map_.find(GetRef<tir::Var>(op));
+    if (it != var_map_.end()) {
+      return (*it).second;
+    } else {
+      auto n = make_object<tir::VarNode>(*op);
+      tir::Var v(n);
+      var_map_.Set(GetRef<tir::Var>(op), v);
+      return v;
+    }
+  }
+
+ private:
+  Map<tir::Var, tir::Var> var_map_;
+};
+
 /*!
  * \brief The ExprMutator used to create a new grouped function
  * \details The workflow of this ExprMutator is:
@@ -466,10 +506,10 @@ class FunctionCreator : public ExprMutator {
       body = builder_->Normalize(body);
       body = builder_->Normalize(SeqExpr({new_block}, body));
       group_attrs.Set(tvm::relax::attr::kPrimitive, Integer(1));
-      function_ = Function(/*params=*/params_,            //
-                           /*body=*/body,                 //
-                           /*ret_struct_info=*/NullOpt,   //
-                           /*attrs=*/DictAttrs(group_attrs));
+      function_ = SymbolicVarRenewMutator::Renew(Function(/*params=*/params_,           //
+                                                          /*body=*/body,                //
+                                                          /*ret_struct_info=*/NullOpt,  //
+                                                          /*attrs=*/DictAttrs(group_attrs)));
     }
   }
 
diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
index ad8770da8df9..45342cf4ffa2 100644
--- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
@@ -76,11 +76,11 @@ class CUDAGraphCache : public Object {
 
   /*!
    * \brief Launch the cuda graph if it has been cached, otherwise execute it in capture mode.
-   * \param vm The virutal machine.
+   * \param vm The virtual machine.
    * \param capture_func The function of type (args...) -> Tuple[ObjectRef], where 'args' are the
   * static arguments that are the same for all invocations of the capture function, the returned
   * tuple contains the intermediate tensors that will be used outside the capture function.
-   * \params args The static arguments of the capture function
+   * \param args The static arguments of the capture function
   * \param entry_index The unique index of the capture function used for lookup.
    * \return The return value of the capture function.
   */
diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py
index 49d2b7601137..b4b68504a489 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -520,5 +520,18 @@ def foo(x: R.Tensor(("m", "n"), dtype="float32")) -> R.Tensor(("m1", "n1"), dtyp
     assert not rx.analysis.well_formed(mod)
 
 
+def test_func_sinfo_well_formed():
+    @R.function
+    def foo():
+        @R.function
+        def local(x: R.Tensor(["m", "n"], "float32")):
+            return x
+
+        return local
+
+    mod = rx.transform.Normalize()(tvm.IRModule.from_expr(foo))
+    assert rx.analysis.well_formed(mod)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py
index 8f7d8bf40fa5..72f4e29a1690 100644
--- a/tests/python/relax/test_transform_fuse_ops.py
+++ b/tests/python/relax/test_transform_fuse_ops.py
@@ -1254,5 +1254,42 @@ def main(inp_0: R.Tensor((1, 784), dtype="float32"), inp_1: R.Tensor((1, 128), d
     _check(mod, Expected)
 
 
+def test_symbolic_shape_aware_fuse():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor(["n", "m"], "float32")):
+            with R.dataflow():
+                lv0 = R.emit_te(topi.add, x, R.const(1, "float32"))
+                lv1 = R.emit_te(topi.exp, lv0)
+                gv = R.emit_te(topi.squeeze, lv1)
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def fused_add_exp_squeeze(
+            x: R.Tensor(["n", "m"], "float32"), p0: R.Tensor([], "float32")
+        ) -> R.Tensor(["n", "m"], dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            with R.dataflow():
+                lv0 = R.emit_te(topi.add, x, p0)
+                lv1 = R.emit_te(topi.exp, lv0)
+                gv = R.emit_te(topi.squeeze, lv1)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(x: R.Tensor(["n", "m"], "float32")) -> R.Tensor(["n", "m"], dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                gv = cls.fused_add_exp_squeeze(x, R.const(1, "float32"))
+                R.output(gv)
+            return gv
+
+    _check(Before, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
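
For reference, below is a minimal usage sketch (not part of the patch) of the behavior this change enables, mirroring test_symbolic_shape_aware_fuse above. The module name DynModule is illustrative, and the exact API surface is assumed from the unity branch at the time of this patch.

# Sketch only: run FuseOps on a module with symbolic shapes and check that the
# result is still well-formed. Assumes a TVM build that includes this patch.
import tvm
from tvm import relax, topi
from tvm.script import ir as I, relax as R


@I.ir_module
class DynModule:
    @R.function
    def main(x: R.Tensor(["n", "m"], "float32")):
        with R.dataflow():
            lv0 = R.emit_te(topi.add, x, R.const(1, "float32"))
            lv1 = R.emit_te(topi.exp, lv0)
            gv = R.emit_te(topi.squeeze, lv1)
            R.output(gv)
        return gv


# FuseOps groups the add/exp/squeeze chain into one primitive function even
# though the shapes are symbolic; with the symbolic-var renewal introduced in
# fuse_ops.cc, the fused function gets fresh tir::Var definitions and the
# resulting module passes the well-formed check.
fused = relax.transform.FuseOps()(DynModule)
assert relax.analysis.well_formed(fused)
print(fused.script())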