diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc
index b695c5f6c7cf..5ddda937051e 100644
--- a/src/relax/transform/fuse_tir.cc
+++ b/src/relax/transform/fuse_tir.cc
@@ -31,27 +31,131 @@ namespace tir {
 
 // TODO(Siyuan): move it to somewhere under tir folder
 /*!
- * \brief Substitute a given source buffer with a given target buffer in statements or expressions.
+ * \brief Match symbolic vars according to the given PrimExpr, and update the var_remap.
+ * Throws an error if there is a mismatch.
  */
-class FuseTIRBufferSubstitor : private StmtExprMutator {
+class SymbolicMatcher : ExprFunctor<bool(const PrimExpr& n, const PrimExpr& other)> {
  public:
-  static Stmt Substitute(const Map<Buffer, Buffer>& buffer_map, Stmt stmt) {
-    return FuseTIRBufferSubstitor(buffer_map)(std::move(stmt));
+  void Match(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
+    CHECK_EQ(lhs.size(), rhs.size());
+    for (size_t i = 0; i < lhs.size(); ++i) {
+      Match(lhs[i], rhs[i]);
+    }
   }
 
+  void Match(const PrimExpr& lhs, const PrimExpr& rhs) {
+    if (!VisitExpr(lhs, rhs)) {
+      LOG(FATAL) << "Failed to match PrimExpr " << lhs << " with " << rhs;
+    }
+  }
+
+  Map<Var, Var> var_remap;
 
  private:
-  explicit FuseTIRBufferSubstitor(const Map<Buffer, Buffer>& buffer_map) {
+  bool VisitExpr(const PrimExpr& n, const PrimExpr& other) {
+    bool matched = n.same_as(other) || ((n->type_index() == other->type_index()) &&
+                                        n.dtype().code() == other.dtype().code());
+    return matched && ExprFunctor::VisitExpr(n, other);
+  }
+
+#define TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(OpName)               \
+  bool VisitExpr_(const OpName* op, const PrimExpr& other) {     \
+    const auto* rhs = other.as<OpName>();                        \
+    ICHECK(rhs);                                                 \
+    return VisitExpr(op->a, rhs->a) && VisitExpr(op->b, rhs->b); \
+  }
+
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(AddNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(SubNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MulNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(DivNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(ModNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(EQNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(NENode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(LTNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(LENode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(GTNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(GENode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(AndNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(OrNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MinNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(MaxNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(FloorDivNode);
+  TVM_DECLARE_SYMBOLIC_MATCHER_BINOP(FloorModNode);
+
+  bool VisitExpr_(const IntImmNode* op, const PrimExpr& other) {
+    const auto* rhs = other.as<IntImmNode>();
+    return op->value == rhs->value;
+  }
+
+  bool VisitExpr_(const FloatImmNode* op, const PrimExpr& other) {
+    const auto* rhs = other.as<FloatImmNode>();
+    return op->value == rhs->value;
+  }
+
+  bool VisitExpr_(const CastNode* op, const PrimExpr& other) {
+    const auto* rhs = other.as<CastNode>();
+    return VisitExpr(op->value, rhs->value);
+  }
+
+  bool VisitExpr_(const VarNode* op, const PrimExpr& other) {
+    const auto* rhs = other.as<VarNode>();
+    auto lhs = GetRef<Var>(op);
+    if (lhs.same_as(other)) return true;
+    if (op->dtype.code() != rhs->dtype.code()) return false;
+    auto it = var_remap.find(lhs);
+    if (it == var_remap.end()) {
+      var_remap.Set(lhs, GetRef<Var>(rhs));
+      return true;
+    } else {
+      return (*it).second.same_as(other);
+    }
+  }
+};
+
+/*!
+ * \brief Substitute a given source buffer with a given target buffer in statements or expressions.
+ */
+class FuseTIRBufferSubstitor : private StmtExprMutator {
+ public:
+  explicit FuseTIRBufferSubstitor(const Map<Buffer, Buffer>& buffer_map,
+                                  const Map<Var, Var>& var_map) {
+    buffer_remap_ = buffer_map;
+    var_remap_ = var_map;
     for (const auto& kv : buffer_map) {
       const Buffer& src = kv.first;
       const Buffer& tgt = kv.second;
-      buffer_var_map_[src->data.get()] = tgt;
+      var_remap_.Set(src->data, tgt->data);
     }
   }
 
+  Stmt Substitute(Stmt stmt) { return this->VisitStmt(std::move(stmt)); }
+
+  Buffer SubstituteAllocatedBuffer(Buffer buffer) {
+    ICHECK(buffer_remap_.find(buffer) == buffer_remap_.end());
+    Array<PrimExpr> shape =
+        MutateArray(buffer->shape, [this](const PrimExpr& expr) { return this->VisitExpr(expr); });
+    Array<PrimExpr> strides = MutateArray(
+        buffer->strides, [this](const PrimExpr& expr) { return this->VisitExpr(expr); });
+    PrimExpr elem_offset = this->VisitExpr(buffer->elem_offset);
+    if (shape.same_as(buffer->shape) && strides.same_as(buffer->strides) &&
+        elem_offset.same_as(buffer->elem_offset)) {
+      return buffer;
+    } else {
+      auto n = make_object<BufferNode>(*buffer.get());
+      n->shape = std::move(shape);
+      n->strides = std::move(strides);
+      n->elem_offset = std::move(elem_offset);
+      Buffer new_buffer(n);
+      this->buffer_remap_.Set(buffer, new_buffer);
+      return new_buffer;
+    }
+  }
+
+ private:
   PrimExpr VisitExpr_(const VarNode* _op) final {
-    auto it = buffer_var_map_.find(_op);
-    if (it != buffer_var_map_.end()) {
-      return it->second->data;
+    auto it = var_remap_.find(GetRef<Var>(_op));
+    if (it != var_remap_.end()) {
+      return (*it).second;
     } else {
       return GetRef<PrimExpr>(_op);
     }
@@ -59,25 +163,25 @@ class FuseTIRBufferSubstitor : private StmtExprMutator {
 
   PrimExpr VisitExpr_(const BufferLoadNode* _op) final {
     BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(_op));
-    auto it = buffer_var_map_.find(load->buffer->data.get());
-    if (it != buffer_var_map_.end()) {
+    const Buffer& buffer = SubstituteBuffer(load->buffer);
+    if (buffer.same_as(load->buffer)) {
+      return std::move(load);
+    } else {
       auto n = make_object<BufferLoadNode>(*load.get());
-      n->buffer = it->second;
+      n->buffer = buffer;
       return BufferLoad(n);
-    } else {
-      return std::move(load);
     }
   }
 
   Stmt VisitStmt_(const BufferStoreNode* _op) final {
     BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(_op));
-    auto it = buffer_var_map_.find(store->buffer->data.get());
-    if (it != buffer_var_map_.end()) {
-      auto n = CopyOnWrite(store.get());
-      n->buffer = it->second;
-      return BufferStore(n);
-    } else {
+    const Buffer& buffer = SubstituteBuffer(store->buffer);
+    if (buffer.same_as(store->buffer)) {
       return std::move(store);
+    } else {
+      auto n = make_object<BufferStoreNode>(*store.get());
+      n->buffer = buffer;
+      return BufferStore(n);
     }
   }
 
@@ -85,21 +189,25 @@ class FuseTIRBufferSubstitor : private StmtExprMutator {
     Block block = Downcast<Block>(StmtMutator::VisitStmt_(_op));
 
     // Define the mutation functions.
+
     auto f_mutate_match_buffers = [this](const MatchBufferRegion& match_buffer) {
-      const Buffer& src_buffer = match_buffer->source->buffer;
-      auto it = buffer_var_map_.find(src_buffer->data.get());
-      if (it != buffer_var_map_.end()) {
-        return MatchBufferRegion(match_buffer->buffer,
-                                 BufferRegion(it->second, match_buffer->source->region));
-      } else {
+      const Buffer& src_buffer = SubstituteBuffer(match_buffer->source->buffer);
+      const Buffer& tgt_buffer = SubstituteAllocatedBuffer(match_buffer->buffer);
+      if (src_buffer.same_as(match_buffer->source->buffer) &&
+          tgt_buffer.same_as(match_buffer->buffer)) {
         return match_buffer;
+      } else {
+        auto n = make_object<MatchBufferRegionNode>(*match_buffer.get());
+        n->buffer = tgt_buffer;
+        n->source = BufferRegion(src_buffer, match_buffer->source->region);
+        return MatchBufferRegion(n);
       }
     };
 
     auto f_mutate_read_write_region = [this](const BufferRegion& buffer_region) {
-      auto it = buffer_var_map_.find(buffer_region->buffer->data.get());
-      return it == buffer_var_map_.end() ? buffer_region
-                                         : BufferRegion(it->second, buffer_region->region);
+      auto it = buffer_remap_.find(buffer_region->buffer);
+      return it == buffer_remap_.end() ? buffer_region
+                                       : BufferRegion((*it).second, buffer_region->region);
     };
 
     // Step 1. Mutate `match_buffers`.
@@ -108,26 +216,34 @@ class FuseTIRBufferSubstitor : private StmtExprMutator {
     // Step 2. Mutate the read/write region.
     Array<BufferRegion> reads = MutateArray(block->reads, f_mutate_read_write_region);
     Array<BufferRegion> writes = MutateArray(block->writes, f_mutate_read_write_region);
+    // Step 3. Mutate the Allocate Buffers.
+    Array<Buffer> alloc_buffers = MutateArray(block->alloc_buffers, [this](const Buffer& buffer) {
+      return SubstituteAllocatedBuffer(buffer);
+    });
 
     reads = UnionAccessRegion(reads);
     writes = UnionAccessRegion(writes);
 
     if (reads.same_as(block->reads) &&    //
         writes.same_as(block->writes) &&  //
-        match_buffers.same_as(block->match_buffers)) {
+        match_buffers.same_as(block->match_buffers) &&
+        alloc_buffers.same_as(block->alloc_buffers)) {
       return std::move(block);
     } else {
      auto n = CopyOnWrite(block.get());
      n->reads = std::move(reads);
      n->writes = std::move(writes);
      n->match_buffers = std::move(match_buffers);
+      n->alloc_buffers = std::move(alloc_buffers);
      return Block(n);
     }
   }
 
  private:
-  /*! \brief Mapping from src buffer.data to tgt buffer. */
-  std::unordered_map<const VarNode*, Buffer> buffer_var_map_;
+  /*! \brief Mapping from src buffer to tgt buffer. */
+  Map<Buffer, Buffer> buffer_remap_;
+  /*! \brief Mapping from src tir var to tgt var. */
+  Map<Var, Var> var_remap_;
 
   /*! \brief The structural equality checker */
   StructuralEqual structural_equal_;
@@ -155,6 +271,15 @@ class FuseTIRBufferSubstitor : private StmtExprMutator {
       return ret;
     }
   }
+
+  inline Buffer SubstituteBuffer(const Buffer& buffer) const {
+    auto it = buffer_remap_.find(buffer);
+    if (it != buffer_remap_.end()) {
+      return (*it).second;
+    } else {
+      return buffer;
+    }
+  }
 };
 
 /*! \brief A mutator which detect block name duplication and deduplicate the names. */
@@ -298,8 +423,8 @@ class FusedTIRConstructor : public ExprVisitor {
     // Step 5. Map input arguments to buffer
     MapInputBuffer(prim_func, call->args[1]);
-    size_t num_output_buffers = GetCallTIROutputSize(call);
-    AllocateIntermediateBuffer(GetRef<Expr>(call), prim_func, num_output_buffers);
+    const Array<Array<PrimExpr>>& output_buffer_shapes = GetCallTIROutputShapes(call);
+    AllocateIntermediateBuffer(GetRef<Expr>(call), prim_func, output_buffer_shapes);
 
     // Update fused func name
     func_info_.global_name += "_" + gv->name_hint;
   }
@@ -343,14 +468,32 @@ class FusedTIRConstructor : public ExprVisitor {
    * \brief Get the number of outputs for a call_tir node.
    * \return The number of outputs.
    */
-  static size_t GetCallTIROutputSize(const CallNode* call) {
+  static Array<Array<PrimExpr>> GetCallTIROutputShapes(const CallNode* call) {
     static const Op& call_tir_op_ = Op::Get("relax.call_tir");
     ICHECK(call->op.same_as(call_tir_op_));
     ICHECK_EQ(call->sinfo_args.size(), 1);
+    auto get_tensor_shape = [](const TensorStructInfoNode* sinfo) {
+      const auto* shape_expr = sinfo->shape.as<ShapeExprNode>();
+      CHECK(shape_expr) << "FuseTIR expects all parameters are Tensors with symbolic shape.";
+      return shape_expr->values;
+    };
     if (const auto* tuple_sinfo = call->sinfo_args[0].as<TupleStructInfoNode>()) {
-      return tuple_sinfo->fields.size();
+      Array<Array<PrimExpr>> shapes;
+      for (const StructInfo& field : tuple_sinfo->fields) {
+        const auto* tensor_sinfo = field.as<TensorStructInfoNode>();
+        CHECK(tensor_sinfo) << "CallTIR sinfo_args are expected to be TensorStructInfo or Tuple of "
+                               "TensorStructInfo, but got "
+                            << call->sinfo_args[0];
+        shapes.push_back(get_tensor_shape(tensor_sinfo));
+      }
+      return shapes;
+    } else if (const auto* tensor_sinfo = call->sinfo_args[0].as<TensorStructInfoNode>()) {
+      return {get_tensor_shape(tensor_sinfo)};
     } else {
-      return 1;
+      CHECK(tensor_sinfo) << "CallTIR sinfo_args are expected to be TensorStructInfo or Tuple of "
+                             "TensorStructInfo, but got "
+                          << call->sinfo_args[0];
+      throw;
     }
   }
@@ -365,17 +508,14 @@ class FusedTIRConstructor : public ExprVisitor {
           for (const tir::Buffer& target_buffer : (*it).second) {
             ICHECK_LT(buffer_idx, buffers.size());
             const tir::Buffer& buffer = buffers[buffer_idx];
-            // TODO(relax-team): Add support for symbolic shape fusion
-            for (const PrimExpr& shape_expr : buffer->shape) {
-              ICHECK(shape_expr.as<IntImmNode>()) << "Only support constant shape fusion for now";
-            }
+            func_info_.symbolic_var_matcher.Match(buffer->shape, target_buffer->shape);
             func_info_.buffer_subst_map.Set(buffer, target_buffer);
             buffer_idx++;
           }
         }
       }
     }
-    // Make sure every buffers are maped.
+    // Make sure every buffer is mapped.
     ICHECK_EQ(buffer_idx, buffers.size());
   }
@@ -408,18 +548,30 @@ class FusedTIRConstructor : public ExprVisitor {
    * intermediate results.
    * \param expr The relax Expr, which can be binding vars or binding values.
    * \param func The old TIR PrimFunc
-   * \param output_size The number of output params. All output params are at the end of param list.
+   * \param output_shapes The shapes of the output params.
    */
-  void AllocateIntermediateBuffer(const Expr& expr, const tir::PrimFunc& func, size_t output_size) {
+  void AllocateIntermediateBuffer(const Expr& expr, const tir::PrimFunc& func,
+                                  const Array<Array<PrimExpr>>& output_shapes) {
     size_t n = func->params.size();
+    size_t output_size = output_shapes.size();
     ICHECK_GE(n, output_size);
     // Allocate intermediate buffer
     Array<tir::Buffer> alloc_buffers;
     for (size_t i = 0; i < output_size; ++i) {
       const tir::Var& param = func->params[n - output_size + i];
       const tir::Buffer& buffer = func->buffer_map.at(param);
-      func_info_.alloc_buffers.push_back(buffer);
-      alloc_buffers.push_back(buffer);
+
+      // Update buffer with new symbolic shape according to the sinfo
+      auto n = make_object<tir::BufferNode>(*buffer.get());
+      n->shape = output_shapes[i];
+      n->name = param->name_hint + "_intermediate";
+      tir::Buffer new_buffer(n);
+      func_info_.alloc_buffers.push_back(new_buffer);
+      alloc_buffers.push_back(new_buffer);
+
+      // Match the original buffer shape against the new symbolic shape
+      func_info_.symbolic_var_matcher.Match(buffer->shape, n->shape);
+      func_info_.buffer_subst_map.Set(buffer, new_buffer);
     }
     // Update expr2buffers
     func_info_.expr2buffers.Set(expr, alloc_buffers);
@@ -438,7 +590,7 @@ class FusedTIRConstructor : public ExprVisitor {
     Array<tir::Var> params;
     Array<tir::Buffer> buffers;
     if (const auto* tensor = struct_info.as<TensorStructInfoNode>()) {
-      // Case 1. the relax param is a DynTensor, we directly create a tir var and buffer
+      // Case 1. the relax param is a Tensor, we directly create a tir var and buffer
       const auto* shape_expr = tensor->shape.as<ShapeExprNode>();
       ICHECK(shape_expr) << "FuseTIR expects all parameters are Tensors with symbolic shape.";
 
@@ -452,7 +604,7 @@ class FusedTIRConstructor : public ExprVisitor {
       params.push_back(std::move(param));
       buffers.push_back(std::move(buffer));
     } else if (const auto* tuple = struct_info.as<TupleStructInfoNode>()) {
-      // Case 2. the relax param is a Tuple, we recursively visit each field until it's a DynTensor
+      // Case 2. the relax param is a Tuple, we recursively visit each field until it's a Tensor
       // Enable postfix
       if (index == -1) index = 0;
       for (size_t i = 0; i < tuple->fields.size(); ++i) {
@@ -478,21 +630,25 @@ class FusedTIRConstructor : public ExprVisitor {
   tir::PrimFunc ConstructFunc() {
     Map<String, ObjectRef> attr_map;
     attr_map.Set("tir.noalias", tir::const_true());
+    tir::FuseTIRBufferSubstitor substitor(func_info_.buffer_subst_map,
+                                          func_info_.symbolic_var_matcher.var_remap);
     ICHECK(func_info_.global_name != "fused");
     // Remove output buffers from func_info_.alloc_buffers
     Array<tir::Buffer> alloc_buffers;
     for (const tir::Buffer& buf : func_info_.alloc_buffers) {
       if (func_info_.output_buffers.count(buf.get()) == 0) {
-        alloc_buffers.push_back(buf);
+        alloc_buffers.push_back(substitor.SubstituteAllocatedBuffer(buf));
       }
     }
     tir::Stmt body = tir::BlockNameDeduplicator()(tir::SeqStmt::Flatten(func_info_.bodies));
-    body = tir::FuseTIRBufferSubstitor::Substitute(func_info_.buffer_subst_map, body);
+
+    body = substitor.Substitute(body);
     body = tir::Block({}, {}, {}, "root", std::move(body), NullOpt, alloc_buffers);
     body = tir::BlockRealize({}, Bool(true), Downcast<tir::Block>(body));
     tir::PrimFunc func(func_info_.params, body, VoidType(), func_info_.buffer_map,
                        DictAttrs(attr_map));
-    return func;
+    // Renew function defs to prevent using the same symbolic vars in different functions
+    return tir::RenewDefs(func);
   }
 
   /*! \brief Get DynTensor numbers from recursive Tuples. */
@@ -539,6 +695,8 @@ class FusedTIRConstructor : public ExprVisitor {
     std::unordered_set<const tir::BufferNode*> output_buffers;
     /*! \brief The name of the fused function */
     std::string global_name = "fused";
+    /*! \brief The map from symbolic var to its corresponding var in the fused function */
+    tir::SymbolicMatcher symbolic_var_matcher;
   };
 
   /*! \brief The IRModule */
diff --git a/src/runtime/contrib/json/json_runtime.h b/src/runtime/contrib/json/json_runtime.h
index 51ce2cffd780..5409078e8599 100644
--- a/src/runtime/contrib/json/json_runtime.h
+++ b/src/runtime/contrib/json/json_runtime.h
@@ -59,7 +59,7 @@ class JSONRuntimeBase : public ModuleNode {
   const char* type_key() const override { return "json"; }  // May be overridden
 
   /*! \brief Get the property of the runtime module .*/
-  int GetPropertyMask() const {
+  int GetPropertyMask() const override {
    return ModulePropertyMask::kBinarySerializable | ModulePropertyMask::kRunnable;
   }
 
diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py
index 356e28d6e910..bdbd9be966de 100644
--- a/tests/python/relax/test_transform_fuse_tir.py
+++ b/tests/python/relax/test_transform_fuse_tir.py
@@ -698,5 +698,100 @@ def main(x: R.Tensor((2, 3), "float32")):
     _check(Module, Module)
 
 
+def test_symbolic_shape_aware_fuse():
+    @I.ir_module
+    class Before:
+        @R.function
+        def fused_add_exp_squeeze(
+            x: R.Tensor(["n", "m"], "float32"), p0: R.Tensor([], "float32")
+        ) -> R.Tensor(["n", "m"], dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            with R.dataflow():
+                lv0 = R.emit_te(topi.add, x, p0)
+                lv1 = R.emit_te(topi.exp, lv0)
+                gv = R.emit_te(topi.squeeze, lv1)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(x: R.Tensor(["n", "m"], "float32")) -> R.Tensor(["n", "m"], dtype="float32"):
+            cls = Before
+            with R.dataflow():
+                gv = cls.fused_add_exp_squeeze(x, R.const(1, "float32"))
+                R.output(gv)
+            return gv
+
+    def fused_add_exp_squeeze(x, p0):
+        return topi.squeeze(topi.exp(topi.add(x, p0)))
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor(["n", "m"], "float32")) -> R.Tensor(["n", "m"], dtype="float32"):
+            with R.dataflow():
+                gv = R.emit_te(fused_add_exp_squeeze, x, R.const(1, "float32"))
+                R.output(gv)
+            return gv
+
+    _check(Before, Expected)
+
+
+def test_symbolic_shape_aware_fuse_with_allocation():
+    def te_mean(x, axis):
+        return topi.divide(topi.sum(x, axis, keepdims=True), 4096)
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def fused_mean_add_tir_sqrt_divide_multiply(
+            x: R.Tensor((1, "n", 4096), dtype="float32"),
+            y: R.Tensor((1, "n", 4096), dtype="float32"),
+            rms_norm_weight: R.Tensor((4096,), dtype="float32"),
+        ) -> R.Tensor((1, "n", 4096), dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            with R.dataflow():
+                lv0 = R.emit_te(te_mean, x, axis=2)
+                lv1 = R.emit_te(topi.add, lv0, lv0)
+                lv2 = R.emit_te(topi.sqrt, lv1)
+                lv3 = R.emit_te(topi.divide, y, lv2)
+                gv = R.emit_te(topi.multiply, rms_norm_weight, lv3)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            x: R.Tensor((1, "n", 4096), dtype="float32"),
+            y: R.Tensor((1, "n", 4096), dtype="float32"),
+            rms_norm_weight: R.Tensor((4096,), dtype="float32"),
+        ) -> R.Tensor((1, "n", 4096), dtype="float32"):
+            cls = Before
+            with R.dataflow():
+                gv = cls.fused_mean_add_tir_sqrt_divide_multiply(x, y, rms_norm_weight)
+                R.output(gv)
+            return gv
+
+    def fused_mean_add_tir_sqrt_divide_multiply(x, y, rms_norm_weight):
+        lv0 = te_mean(x, axis=2)
+        lv1 = topi.add(lv0, lv0)
+        lv2 = topi.sqrt(lv1)
+        lv3 = topi.divide(y, lv2)
+        return topi.multiply(rms_norm_weight, lv3)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((1, "n", 4096), dtype="float32"),
+            y: R.Tensor((1, "n", 4096), dtype="float32"),
+            rms_norm_weight: R.Tensor((4096,), dtype="float32"),
+        ) -> R.Tensor((1, "n", 4096), dtype="float32"):
+            with R.dataflow():
+                gv = R.emit_te(fused_mean_add_tir_sqrt_divide_multiply, x, y, rms_norm_weight)
+                R.output(gv)
+            return gv
+
+    _check(Before, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
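Usage sketch (not part of the patch above): a minimal, hypothetical example of the behaviour the new tests exercise, assuming the public relax.transform.FuseTIR() entry point and the same TVMScript / R.emit_te conventions as test_transform_fuse_tir.py; the module and function names below are illustrative only. Before this change the pass required constant shapes ("Only support constant shape fusion for now"); with SymbolicMatcher the symbolic dims are matched and remapped instead, and tir::RenewDefs gives each fused PrimFunc its own fresh symbolic vars.

# Hypothetical sketch only; mirrors the new tests rather than adding coverage.
import tvm
from tvm import relax, topi
from tvm.script import ir as I, relax as R


@I.ir_module
class Module:
    @R.function
    def fused_exp(x: R.Tensor(["n", "m"], "float32")) -> R.Tensor(["n", "m"], "float32"):
        R.func_attr({"Primitive": 1})
        with R.dataflow():
            gv = R.emit_te(topi.exp, x)
            R.output(gv)
        return gv

    @R.function
    def main(x: R.Tensor(["n", "m"], "float32")) -> R.Tensor(["n", "m"], "float32"):
        cls = Module
        with R.dataflow():
            gv = cls.fused_exp(x)
            R.output(gv)
        return gv


# The symbolic dims "n"/"m" are matched across the fused PrimFuncs instead of
# being rejected, so the resulting fused kernel keeps a symbolic shape.
fused_mod = relax.transform.FuseTIR()(Module)
fused_mod.show()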