From db735e3d8a2e60a8f2ed603b2342d72f2e16bc7b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 22 Jan 2024 17:04:10 +0000 Subject: [PATCH] [Unity][Transform] Check for infer-able expressions in FuseOps Prior to this commit, `FuseOps` and `FuseOpsByPattern` exposed a symbolic variable to the fused function if it was used within the fused function, but wasn't inferable from other parameter shapes. While this prevents undefined symbolic variables, it can cause issues for downstream use of `CodegenJSON`, which requires all arguments to be tensors, or tuple of tensors. Frequently, all uses of a non-inferable symbolic shape occur within a symbolic expression that can be inferred. For example, a function that takes `arg: R.Tensor([N+1])` and returns `R.add(arg, R.const(1))` cannot infer `N`. However, all occurrences of `N` occur as part of the expression `N+1`, and the value of `N+1` can be inferred. Therefore, if we replace `N+1` with `M`, the additional `ShapeTuple` argument isn't required. --- src/relax/transform/fuse_ops.cc | 206 +++++++++++++++++- src/support/ordered_set.h | 25 ++- .../test_transform_fuse_ops_by_pattern.py | 77 +++++++ 3 files changed, 292 insertions(+), 16 deletions(-) diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index a2a3e96dd567..feab261e3076 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -38,9 +38,11 @@ #include #include +#include #include "../../relay/analysis/graph_partitioner.h" #include "../../support/arena.h" +#include "../../support/ordered_set.h" #include "tvm/relax/expr.h" #include "utils.h" @@ -360,6 +362,169 @@ class GraphCreator : public ExprVisitor { std::unordered_set initialized_nodes_; }; +class InferredCommonSubexpressionCollector : relax::ExprVisitor, + StructInfoVisitor, + tir::ExprVisitor { + public: + struct InferResult { + // A list of additional symbolic variables that must be provided + // to the function. These variables cannot be inferred from the + // StructInfo of the existing parameters. + Array symbolic_vars; + + // A list of expressions, each of which must be remapped to a new + // symbolic variable. These expressions can be inferred from the + // StructInfo of the existing parameters, but may contain + // sub-expressions that cannot. + Array symbolic_expressions; + }; + + static InferResult Infer(Array params, Expr body) { + InferredCommonSubexpressionCollector collector; + collector.VisitStructInfo(TupleStructInfo(params.Map(GetStructInfo))); + collector.phase_ = Phase::CollectRequiredExpressions; + collector.VisitExpr(body); + + return InferResult{ + Array(collector.required_symbolic_vars_.begin(), + collector.required_symbolic_vars_.end()), + Array(collector.required_symbolic_exprs_.begin(), + collector.required_symbolic_exprs_.end()), + }; + } + + private: + using relax::ExprVisitor::VisitExpr; + using relax::ExprVisitor::VisitExpr_; + using tir::ExprVisitor::VisitExpr; + using tir::ExprVisitor::VisitExpr_; + + void VisitExprDepStructInfoField(const StructInfo& struct_info) override { + VisitStructInfo(struct_info); + } + void VisitStructInfoExprField(const Expr& expr) override { VisitStructInfo(GetStructInfo(expr)); } + void VisitStructInfoExprField(const PrimExpr& expr) override { + if (expr->IsInstance()) { + return; + } + + switch (phase_) { + case Phase::CollectInferableExpressions: + inferable_expressions_.insert(expr); + break; + + case Phase::CollectRequiredExpressions: + VisitExpr(expr); + break; + + default: + LOG(FATAL) << "Invalid value for Phase: " << static_cast(phase_); + break; + } + } + + void VisitExpr(const PrimExpr& expr) override { + if (inferable_expressions_.count(expr)) { + required_symbolic_exprs_.insert(expr); + } else { + tir::ExprVisitor::VisitExpr(expr); + } + } + + void VisitExpr_(const tir::VarNode* op) override { + required_symbolic_vars_.push_back(GetRef(op)); + } + + enum class Phase { + CollectInferableExpressions, + CollectRequiredExpressions, + }; + Phase phase_ = Phase::CollectInferableExpressions; + std::unordered_set inferable_expressions_; + support::OrderedSet required_symbolic_vars_; + support::OrderedSet required_symbolic_exprs_; +}; + +/* \brief Replace occurrences of a PrimExpr in the symbolic variables + * + * In most cases, the `tvm::relax::Bind` utility should be used + * instead. Here, though, we are replacing a `PrimExpr` with a + * `tir::Var`, whereas `tvm::relax::Bind` supports the more standard + * case of replacing a `tir::Var` with a `PrimExpr`. + */ +class SymbolicSubexprReplacer : relax::ExprMutator, StructInfoMutator, tir::ExprMutator { + public: + /* \brief Replace occurrences of a PrimExpr in the symbolic variables + * + * In most cases, the `tvm::relax::Bind` utility should be used + * instead. Here, though, we are replacing a `PrimExpr` with a + * `tir::Var`, rather than the other way around. + * + * \param relax_expr The expression in which to replace symbolic expressions + * + * \param symbolic_exprs A list of expressions, each of which should + * be replaced with a new symbolic variable. This is provided as a + * list, rather than as a replacement map, to allow context-dependent + * names to be generated for these expressions. + * + * \returns The updated relax expression. + */ + static Expr Replace(const Expr& relax_expr, Array symbolic_exprs) { + std::unordered_map, StructuralHash, StructuralEqual> replacements; + for (const auto& expr : symbolic_exprs) { + replacements.insert({expr, NullOpt}); + } + + SymbolicSubexprReplacer mutator(replacements); + return mutator(relax_expr); + } + + private: + using relax::ExprMutator::operator(); + using relax::ExprMutator::VisitExpr; + using tir::ExprMutator::operator(); + using tir::ExprMutator::VisitExpr; + + SymbolicSubexprReplacer( + std::unordered_map, StructuralHash, StructuralEqual> + replacements) + : replacements_(replacements) {} + + StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info) override { + return VisitStructInfo(struct_info); + } + Expr VisitStructInfoExprField(const Expr& expr) override { return VisitExpr(expr); } + PrimExpr VisitStructInfoExprField(const PrimExpr& expr) override { return VisitExpr(expr); } + PrimExpr VisitPrimExpr(const PrimExpr& expr) override { return VisitExpr(expr); } + + PrimExpr VisitExpr(const PrimExpr& expr) override { + if (auto replacement = GetReplacement(expr)) { + return replacement.value(); + } else { + return tir::ExprMutator::VisitExpr(expr); + } + } + + Optional GetReplacement(const PrimExpr& expr) { + auto it = replacements_.find(expr); + if (it == replacements_.end()) { + return NullOpt; + } + + Optional& opt_var = it->second; + if (!opt_var.defined()) { + // Ideally, this path would never be reached, as it doesn't + // provide as much context in the variable name. However, it's + // useful as a fallback. + opt_var = tir::Var("fused_expr", expr->dtype); + } + + return opt_var.value(); + } + + std::unordered_map, StructuralHash, StructuralEqual> replacements_; +}; + /*! * \brief The ExprMutator used to create a new grouped function * \details The workflow of this ExprMutator is: @@ -533,25 +698,44 @@ class FunctionCreator : public ExprMutator { function_ = NullOpt; } else { Expr body = outputs.size() == 1 ? outputs[0] : Tuple(outputs); + body = SeqExpr({new_block}, body); body = builder_->Normalize(body); - body = builder_->Normalize(SeqExpr({new_block}, body)); + + // Any symbolic variables that are required within the body of + // the function, but cannot be inferred from the parameters of + // the function, must be exposed using an additional argument. + auto [symbolic_vars, symbolic_expressions] = + InferredCommonSubexpressionCollector::Infer(params_, body); + if (symbolic_vars.size()) { + auto symbolic_vars_as_expr = + symbolic_vars.Map([](tir::Var var) -> PrimExpr { return var; }); + params_.push_back(Var("tir_vars", ShapeStructInfo(symbolic_vars_as_expr))); + arguments_.push_back(ShapeExpr(symbolic_vars_as_expr)); + } + group_attrs.Set(tvm::relax::attr::kPrimitive, Integer(1)); Function function = Function(/*params=*/params_, // /*body=*/body, // /*ret_struct_info=*/NullOpt, // /*is_pure=*/true, // /*attrs=*/DictAttrs(group_attrs)); - Array free_vars = - FreeSymbolicVars(function).Map([](const tir::Var& var) -> PrimExpr { return var; }); - if (!free_vars.empty()) { - params_.push_back(Var("tir_vars", ShapeStructInfo(free_vars))); - arguments_.push_back(ShapeExpr(free_vars)); - function = Function(/*params=*/params_, // - /*body=*/body, // - /*ret_struct_info=*/NullOpt, // - /*is_pure=*/true, // - /*attrs=*/DictAttrs(group_attrs)); + + // If the function contains symbolic expressions that can be + // inferred from the parameters, but contain subexpressions that + // cannot be inferred from the parameters, those expressions + // should be replaced with symbolic variables. + // + // For example, suppose a fused function maps from a tensor of + // shape `[batch_size+1, 1024]` to `[batch_size+1,1024]`. It + // cannot infer `batch_size`, but could infer the value of + // `batch_size+1`. By introducing `batch_size_plus_one = + // batch_size+1`, we can rely on just the infer-able symbolic + // vars. + if (symbolic_expressions.size()) { + function = + Downcast(SymbolicSubexprReplacer::Replace(function, symbolic_expressions)); } + function_ = SymbolicVarRenewMutator::Renew(function); } } diff --git a/src/support/ordered_set.h b/src/support/ordered_set.h index 741f0b18e6b9..50536f2310e6 100644 --- a/src/support/ordered_set.h +++ b/src/support/ordered_set.h @@ -26,6 +26,7 @@ #include +#include #include #include @@ -39,17 +40,31 @@ namespace detail { */ template struct OrderedSetLookupType { - using MapType = std::unordered_map::iterator>; + using Hash = std::hash; + using Equal = std::equal_to; }; template struct OrderedSetLookupType>> { - using MapType = std::unordered_map::iterator, runtime::ObjectPtrHash, - runtime::ObjectPtrEqual>; + using Hash = runtime::ObjectPtrHash; + using Equal = runtime::ObjectPtrEqual; }; } // namespace detail -template +/* \brief Utility to hold an ordered set + * + * \tparam T The type held by the OrderedSet + * + * \tparam LookupHash The hash implementation to use for detecting + * duplicate entries. If unspecified, defaults to `ObjectPtrHash` for + * TVM types, and `std::hash` otherwise. + * + * \tparam LookupEqual The equality-checker to use for detecting + * duplicate entries. If unspecified, defaults to `ObjectPtrEqual` + * for TVM types, and `std::equal_to` otherwise. + */ +template ::Hash, + typename LookupEqual = typename detail::OrderedSetLookupType::Equal> class OrderedSet { public: OrderedSet() = default; @@ -91,7 +106,7 @@ class OrderedSet { private: std::list elements_; - typename detail::OrderedSetLookupType::MapType elem_to_iter_; + std::unordered_map::iterator, LookupHash, LookupEqual> elem_to_iter_; }; } // namespace support diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py index 5e700b277f32..b331570d622f 100644 --- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py +++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py @@ -1217,5 +1217,82 @@ def inner_func( tvm.ir.assert_structural_equal(Expected, After) +def test_matmul_symbolic_expr(): + """Like `test_matmul_symbolic_var`, but with a PrimExpr shape + + The shape of weights used in the matmul are `[1024, M + 1024]`, + which can result from `CombineParallelMatmul`. If the fused + function is written in terms of `M`, then `M` must be provided as + an additional `ShapeExpr`, as it cannot be inferred from the + tensor shape. This can cause issues for downstream passes, as + CodeGenJSON, used by the TVM's runtime for cublas and cutlass, + only supports `R.Tensor` and tuples of `R.Tensor`. + + If a symbolic variable is only used within expressions that + themselves are inferable from the tensor shapes, then the fused + function could be written in terms of that expression, removing + the need for the `ShapeExpr`. Here, the expression `M + 1024` is + replaced by the variable `w2_size`. + """ + + @I.ir_module + class Before: + @R.function + def main( + x: R.Tensor(["batch_size", 1024], dtype="float16"), + w1: R.Tensor([1024, 1024], dtype="float16"), + w2: R.Tensor([1024, "M"], dtype="float16"), + ) -> R.Tensor(["batch_size", "M + 1024"], "float16"): + with R.dataflow(): + concat = R.concat([w1, w2], axis=1) + out = R.matmul(x, concat) + R.output(out) + return out + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor(["batch_size", 1024], dtype="float16"), + w1: R.Tensor([1024, 1024], dtype="float16"), + w2: R.Tensor([1024, "M"], dtype="float16"), + ) -> R.Tensor(["batch_size", "M + 1024"], "float16"): + cls = Expected + with R.dataflow(): + concat = R.concat([w1, w2], axis=1) + out = cls.fused_relax_matmul_cublas(x, concat) + R.output(out) + return out + + @R.function + def fused_relax_matmul_cublas( + x: R.Tensor(["batch_size", 1024], dtype="float16"), + w2: R.Tensor([1024, "w2_size"], dtype="float16"), + ) -> R.Tensor(["batch_size", "w2_size"], dtype="float16"): + batch_size = T.int64() + w2_size = T.int64() + R.func_attr({"Codegen": "cublas"}) + + @R.function + def inner_func( + x: R.Tensor([batch_size, 1024], dtype="float16"), + w2: R.Tensor((1024, w2_size), dtype="float16"), + ) -> R.Tensor([batch_size, w2_size], dtype="float16"): + R.func_attr({"Composite": "cublas.matmul"}) + with R.dataflow(): + out = R.matmul(x, w2) + R.output(out) + return out + + out = inner_func(x, w2) + return out + + patterns = relax.backend.pattern_registry.get_patterns_with_prefix("cublas.matmul") + After = relax.transform.FuseOpsByPattern(patterns, bind_constants=False, annotate_codegen=True)( + Before + ) + tvm.ir.assert_structural_equal(Expected, After) + + if __name__ == "__main__": pytest.main([__file__])