diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index b0eeba399e90..32780f6dd253 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -183,6 +183,8 @@ class GraphCreator : public ExprVisitor { ICHECK_NOTNULL(binding_var_node); static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace"); + OpPatternKind pattern = OpPatternKind::kOpaque; Array args = call->args; @@ -191,7 +193,7 @@ class GraphCreator : public ExprVisitor { // - Otherwise, the pattern of the current binding variable node is set to `kOpaque`, and we // recurse into the call expression. const auto* op = call->op.as(); - if (op == call_tir_op_.get()) { + if (op == call_tir_op_.get() || op == call_tir_inplace_op_.get()) { const GlobalVar& global_var = Downcast(call->args[0]); tir::PrimFunc func = Downcast(mod_->Lookup(global_var)); @@ -377,7 +379,8 @@ class FunctionCreator : public ExprMutator { * function accordingly * \param binding The binding to be appended * \note Allowed bindings are: - * - VarBinding with value being a call node calling `relax.call_tir`. + * - VarBinding with value being a call node calling `relax.call_tir` or + * `relax.call_tir_inplace`. * - VarBinding with value being a tuple-get-item node. * // TODO(tvm-team): handle match shape */ @@ -387,7 +390,8 @@ class FunctionCreator : public ExprMutator { if (const auto* var_binding = binding.as()) { if (const auto* call = var_binding->value.as()) { - if (call->op == Op::Get("relax.call_tir")) { + if (call->op == Op::Get("relax.call_tir") || + call->op == Op::Get("relax.call_tir_inplace")) { // Update the name of the function. name_hint_ = name_hint_ + "_" + Downcast(call->args[0])->name_hint; diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 1c25229d88f8..4ad291e91cce 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include #include #include @@ -367,9 +368,10 @@ class FusedTIRConstructor : public ExprVisitor { * \brief Construct a fused TIR PrimFunc from a relax sub-function * \param mod The IRModule * \param gv The global var of relax subfunction to be fused into one PrimFunc - * \return The fused TIR PrimFunc + * \return The fused TIR PrimFunc and the in-place indices (non-empty for an in-place call) */ - static tir::PrimFunc GetFusedTIR(const IRModule& mod, const GlobalVar& gv) { + static std::pair> GetFusedTIR(const IRModule& mod, + const GlobalVar& gv) { FusedTIRConstructor visitor(mod, gv->name_hint); BaseFunc f = mod->Lookup(gv); CHECK(f->IsInstance()) @@ -377,7 +379,11 @@ class FusedTIRConstructor : public ExprVisitor { CHECK(f->HasNonzeroAttr(relax::attr::kPrimitive)) << "Expected a function with attr `kPrimitive`"; visitor(Downcast(f)); - return visitor.fused_tir_; + Array inplace_indices; + for (size_t idx : visitor.inplace_indices_) { + inplace_indices.push_back(Integer(idx)); + } + return {visitor.fused_tir_, inplace_indices}; } private: @@ -438,9 +444,38 @@ class FusedTIRConstructor : public ExprVisitor { auto it = func_info_.expr2buffers.find(body); ICHECK(it != func_info_.expr2buffers.end()) << "Fail to detect output buffers for function body"; + const Array& buffers = (*it).second; + + // map of input buffers to indices (helpful for detecting in-place inputs) + std::unordered_map buffer_to_idx; + std::unordered_map input_to_idx; + for (size_t i = 0; i < func_info_.params.size(); i++) { + input_to_idx[func_info_.params[i]] = i; + } + for (auto [var, buffer] : func_info_.buffer_map) { + if (auto it = input_to_idx.find(var); it != input_to_idx.end()) { + buffer_to_idx[buffer] = (*it).second; + } + } + + // numbered separately because the number of output *vars* might differ from the + // number of outputs if there are in-place inputs + int out_idx = 0; for (size_t i = 0; i < buffers.size(); ++i) { - tir::Var param = tir::Var("p_output" + std::to_string(i), PrimType(DataType::Handle())); + // Do not add output vars for in-place inputs + // (i.e., already listed in the buffer map. This would result + // in duplicates in the buffer map otherwise) + if (auto it = buffer_to_idx.find(buffers[i]); it != buffer_to_idx.end()) { + auto idx = (*it).second; + CHECK(!inplace_indices_.count(idx)) + << "In-place index " << idx << " used twice! An argument must be aliased."; + inplace_indices_.insert(idx); + continue; + } + + tir::Var param = tir::Var("p_output" + std::to_string(out_idx), PrimType(DataType::Handle())); + out_idx++; func_info_.buffer_map.Set(param, buffers[i]); func_info_.params.push_back(param); func_info_.output_buffers.insert(buffers[i].get()); @@ -476,8 +511,11 @@ class FusedTIRConstructor : public ExprVisitor { void VisitExpr_(const CallNode* call) final { ExprVisitor::VisitExpr_(call); static const Op& call_tir_op_ = Op::Get("relax.call_tir"); - ICHECK(call->op == call_tir_op_) - << "Only call_tir is supported in primitive function, but got: " << GetRef(call); + static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace"); + + ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_) + << "Only call_tir and call_tir_inplace are supported in primitive function, but got: " + << GetRef(call); // Step 1. Get Global var and PrimFunc GlobalVar gv = Downcast(call->args[0]); @@ -503,7 +541,7 @@ class FusedTIRConstructor : public ExprVisitor { MapInputBuffer(prim_func, call->args[1]); const Array>& output_buffer_shapes = GetCallTIROutputShapes(call); - AllocateIntermediateBuffer(GetRef(call), prim_func, output_buffer_shapes); + AllocateIntermediateBuffer(call, prim_func, output_buffer_shapes); // Step 6. Update tir_vars if (call->args.size() > 2) { @@ -566,7 +604,8 @@ class FusedTIRConstructor : public ExprVisitor { */ static Array> GetCallTIROutputShapes(const CallNode* call) { static const Op& call_tir_op_ = Op::Get("relax.call_tir"); - ICHECK(call->op.same_as(call_tir_op_)); + static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace"); + ICHECK(call->op.same_as(call_tir_op_) || call->op.same_as(call_tir_inplace_op_)); ICHECK_EQ(call->sinfo_args.size(), 1); auto get_tensor_shape = [](const TensorStructInfoNode* sinfo) { const auto* shape_expr = sinfo->shape.as(); @@ -611,7 +650,7 @@ class FusedTIRConstructor : public ExprVisitor { } } } - // Make sure every buffers are mapped. + // Make sure every buffer is mapped. ICHECK_EQ(buffer_idx, buffers.size()); } @@ -639,28 +678,49 @@ class FusedTIRConstructor : public ExprVisitor { MapArgsToBuffer(arg_list, buffer_list); } - static Array GetPrimFuncOutputParams(const tir::PrimFunc& func, size_t output_size) { + static Array GetInplaceOutputIndices(const Array& inplace_indices, + int num_inputs) { + Array ret; + int last_idx = num_inputs; + for (auto idx : inplace_indices) { + int i = idx.IntValue(); + if (i >= 0) { + ret.push_back(Integer(i)); + } else { + ret.push_back(Integer(last_idx)); + last_idx++; + } + } + + return ret; + } + + static Array GetPrimFuncOutputParams(const tir::PrimFunc& func, + const Array& output_indices) { size_t n = func->params.size(); int symbolic_var_index = -1; + size_t output_size = output_indices.size(); ICHECK_GE(n, output_size); - for (size_t i = 0; i < n; ++i) { - const tir::Var& param = func->params[i]; + + Array ret; + for (auto idx : output_indices) { + int i = idx.IntValue(); + const tir::Var& param = func->params[static_cast(i)]; if (param->dtype.is_int() || param->dtype.is_uint()) { if (symbolic_var_index == -1) symbolic_var_index = i; } else if (param->dtype.is_handle()) { CHECK(symbolic_var_index == -1) << "The scalar input should be at the ending of the " "parameter list."; + ret.push_back(param); } else { LOG(FATAL) << "The params of PrimFunc are expected to be Buffer handle or scalar, but got: " << param->dtype; } } + size_t end_index = symbolic_var_index == -1 ? n : symbolic_var_index; ICHECK_GE(end_index, output_size); - size_t begin_index = end_index - output_size; - Array output_params{func->params.begin() + begin_index, - func->params.begin() + end_index}; - return output_params; + return ret; } /*! @@ -670,18 +730,39 @@ class FusedTIRConstructor : public ExprVisitor { * \param func The old TIR PrimFunc * \param output_shapes The shape of output params. */ - void AllocateIntermediateBuffer(const Expr& expr, const tir::PrimFunc& func, + void AllocateIntermediateBuffer(const CallNode* call, const tir::PrimFunc& func, const Array>& output_shapes) { + bool is_inplace = (call->op == Op::Get("relax.call_tir_inplace")); + size_t n = func->params.size(); + int num_inputs = Downcast(call->args[1])->fields.size(); size_t output_size = output_shapes.size(); ICHECK_GE(n, output_size); - // Allocate intermediate buffer - Array alloc_buffers; - Array output_params = GetPrimFuncOutputParams(func, output_size); + Array output_buffers; + Array output_idxs; + if (is_inplace) { + const auto* attrs = call->attrs.as(); + CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call"; + output_idxs = std::move(GetInplaceOutputIndices(attrs->inplace_indices, num_inputs)); + } else { + for (size_t i = 0; i < output_size; i++) { + output_idxs.push_back(num_inputs + i); + } + } + + Array output_params = GetPrimFuncOutputParams(func, output_idxs); + auto input_buffers = func_info_.expr2buffers.Get(call->args[1]); for (size_t i = 0; i < output_size; ++i) { const tir::Var& param = output_params[i]; const tir::Buffer& buffer = func->buffer_map.at(param); + // if this is an inplace output, do not do an intermediate allocation + if (output_idxs[i].IntValue() < num_inputs) { + CHECK(input_buffers.defined()) << "Inplace functions must have some defined input"; + output_buffers.push_back(input_buffers.value()[output_idxs[i].IntValue()]); + continue; + } + auto unify_name_hints = [this, &buffer]() { String base_name = buffer->name; String unique_name = base_name + "_intermediate"; @@ -703,14 +784,14 @@ class FusedTIRConstructor : public ExprVisitor { n->name = unify_name_hints(); tir::Buffer new_buffer(n); func_info_.alloc_buffers.push_back(new_buffer); - alloc_buffers.push_back(new_buffer); + output_buffers.push_back(new_buffer); // Match the shape of the output buffer with the shape func_info_.symbolic_var_matcher.Match(buffer->shape, n->shape); func_info_.buffer_subst_map.Set(buffer, new_buffer); } // Update expr2buffers - func_info_.expr2buffers.Set(expr, alloc_buffers); + func_info_.expr2buffers.Set(GetRef(call), output_buffers); } /*! @@ -858,6 +939,8 @@ class FusedTIRConstructor : public ExprVisitor { FuseFuncInfo func_info_; /*! \brief The tir function after fusion*/ tir::PrimFunc fused_tir_; + /*! \brief Indices of inputs that are used for in-place computation */ + std::unordered_set inplace_indices_; }; std::vector GetTupleAccessedIndices(const FunctionNode* func, const Var& tuple_var) { @@ -897,8 +980,11 @@ class TIRFuseMutator : public ExprMutator { for (const auto& [gv, func] : mod->functions) { // Only fuse primitive relax functions if (func->IsInstance() && func->HasNonzeroAttr(attr::kPrimitive)) { - tir::PrimFunc fused_tir = FusedTIRConstructor::GetFusedTIR(mod, gv); - mutator.fused_tir_funcs_.Set(gv, fused_tir); + const auto& [prim_func, indices] = FusedTIRConstructor::GetFusedTIR(mod, gv); + mutator.fused_tir_funcs_.Set(gv, prim_func); + if (!indices.empty()) { + mutator.inplace_indices_.Set(gv, indices); + } } } @@ -945,6 +1031,7 @@ class TIRFuseMutator : public ExprMutator { Expr VisitExpr_(const CallNode* op) final { static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace"); Call call = Downcast(builder_->Normalize(ExprMutator::VisitExpr_(op))); @@ -985,26 +1072,34 @@ class TIRFuseMutator : public ExprMutator { CHECK(prim_value->value.defined()) << "FuseTIR requires all R.Prim arguments to have a known value."; PrimExpr expr = prim_value->value.value(); - CHECK(expr->IsInstance()) - << "FuseTIR currently requires all R.Prim arguments to provide a single tir::Var."; + CHECK(expr->IsInstance()) << "FuseTIR currently requires all R.Prim " + "arguments to provide a single tir::Var."; tir_vars.push_back(expr); } else { arg_list.push_back(arg); } } - // Step b. Create call_tir + // Step b. Create call_tir or call_tir_inplace Array call_args = {fused_tir_gv, Tuple(arg_list)}; if (!tir_vars.empty()) { call_args.push_back(ShapeExpr(tir_vars)); } - return Call(call_tir_op_, call_args, call->attrs, {GetStructInfo(call)}); + Op call_op = call_tir_op_; + Attrs call_attrs = call->attrs; + if (auto it = inplace_indices_.find(old_gv); it != inplace_indices_.end()) { + call_op = call_tir_inplace_op_; + auto inplace_attrs = make_object(); + inplace_attrs->inplace_indices = (*it).second; + call_attrs = Attrs(inplace_attrs); + } + return Call(call_op, call_args, call_attrs, {GetStructInfo(call)}); } else { // Case 1.2. The callee function is not primitive, nothing to do. return call; } - } else if (call->op == call_tir_op_) { - // Case 2. It is a call_tir, re-emit the PrimFunc. + } else if (call->op == call_tir_op_ || call->op == call_tir_inplace_op_) { + // Case 2. It is a call_tir or call_tir_inplace, re-emit the PrimFunc. if (const auto* gv = call->args[0].as()) { tir::PrimFunc func = Downcast(mod_->Lookup(GetRef(gv))); GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint); @@ -1023,6 +1118,9 @@ class TIRFuseMutator : public ExprMutator { const IRModule& mod_; /*! \brief The map from global var of primitive relax function to generated prim func. */ Map fused_tir_funcs_; + /*! \brief The map from global var of primitive relax function to in-place indices + * (if there are any). */ + Map> inplace_indices_; }; IRModule FuseTIR(IRModule mod) { diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py index 1a4a630e3e5a..3cd608d8ee8f 100644 --- a/tests/python/relax/test_transform_fuse_ops.py +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -1501,5 +1501,146 @@ def main( _check(Module, Expected) +def test_call_tir_inplace(): + @I.ir_module + class Module: + @T.prim_func(private=True) + def add( + A: T.Buffer((T.int64(10), T.int64(20)), "float32"), + B: T.Buffer((), "float32"), + Out: T.Buffer((T.int64(10), T.int64(20)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], B[()]) + T.writes(Out[v_ax0, v_ax1]) + Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()] + + @T.prim_func(private=True) + def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + for i0, i1 in T.grid(T.int64(10), T.int64(20)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[v_i0, v_i1]) + T.writes(A[v_i0, v_i1]) + A[v_i0, v_i1] = T.exp(A[v_i0, v_i1]) + + @T.prim_func(private=True) + def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_squeeze"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1]) + T.writes(A[v_ax0, v_ax1]) + A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + + @R.function + def main( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + cls = Module + with R.dataflow(): + lv = R.call_tir( + cls.add, + (x, p0), + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + lv1 = R.call_tir_inplace( + cls.exp_inplace, + (lv,), + inplace_indices=[0], + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + gv = R.call_tir_inplace( + cls.squeeze_inplace, + (lv1,), + inplace_indices=[0], + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @T.prim_func(private=True) + def add( + A: T.Buffer((T.int64(10), T.int64(20)), "float32"), + B: T.Buffer((), "float32"), + Out: T.Buffer((T.int64(10), T.int64(20)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True), "op_pattern": 0}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], B[()]) + T.writes(Out[v_ax0, v_ax1]) + Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()] + + @T.prim_func(private=True) + def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): + T.func_attr({"tir.noalias": T.bool(True), "op_pattern": 0}) + for i0, i1 in T.grid(T.int64(10), T.int64(20)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[v_i0, v_i1]) + T.writes(A[v_i0, v_i1]) + A[v_i0, v_i1] = T.exp(A[v_i0, v_i1]) + + @T.prim_func(private=True) + def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): + T.func_attr({"tir.noalias": T.bool(True), "op_pattern": 0}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_squeeze"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1]) + T.writes(A[v_ax0, v_ax1]) + A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + + @R.function(private=True) + def fused_add_exp_inplace_squeeze_inplace( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + R.func_attr({"Primitive": 1}) + cls = Expected + with R.dataflow(): + lv = R.call_tir( + cls.add, + (x, p0), + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + lv1 = R.call_tir_inplace( + cls.exp_inplace, + (lv,), + inplace_indices=[0], + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + gv = R.call_tir_inplace( + cls.squeeze_inplace, + (lv1,), + inplace_indices=[0], + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + R.output(gv) + return gv + + @R.function + def main( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + cls = Expected + with R.dataflow(): + gv1: R.Tensor( + (10, 20), dtype="float32" + ) = cls.fused_add_exp_inplace_squeeze_inplace(x, p0) + R.output(gv1) + return gv1 + + _check(Module, Expected) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index 143670c70180..c0a6f4448b5c 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -1930,5 +1930,329 @@ def main( _check(Before, After) +def test_inplace_simple(): + @I.ir_module + class Module: + I.module_attrs({"foo": "bar"}) + + @T.prim_func(private=True) + def add_inplace( + A: T.Buffer((T.int64(10), T.int64(20)), "float32"), B: T.Buffer((), "float32") + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + # T.reads(A[v_ax0, v_ax1], B[()]) + # T.writes(A[v_ax0, v_ax1]) + A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()] + + @T.prim_func(private=True) + def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + for i0, i1 in T.grid(T.int64(10), T.int64(20)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + # T.reads(A[v_i0, v_i1]) + # T.writes(A[v_i0, v_i1]) + A[v_i0, v_i1] = T.exp(A[v_i0, v_i1]) + + @T.prim_func(private=True) + def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_squeeze"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + # T.reads(A[v_ax0, v_ax1]) + # T.writes(A[v_ax0, v_ax1]) + A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + + @R.function(private=True) + def fused_add_exp_squeeze( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + R.func_attr({"Primitive": 1}) + cls = Module + with R.dataflow(): + # This overwrites x and is actually evil because the function is marked as pure + # but we are doing it just to test the pass. The automatic DataflowUseInplaceCalls + # transformation will not produce code like this, but it may make sense to do it + # if ownership of x is fully and truly transferred. + # Users should apply with caution! + lv = R.call_tir_inplace( + cls.add_inplace, + (x, p0), + inplace_indices=[0], + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + lv1 = R.call_tir_inplace( + cls.exp_inplace, + (lv,), + inplace_indices=[0], + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + gv = R.call_tir_inplace( + cls.squeeze_inplace, + (lv1,), + inplace_indices=[0], + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + R.output(gv) + return gv + + @R.function + def main( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + cls = Module + with R.dataflow(): + gv1: R.Tensor((10, 20), dtype="float32") = cls.fused_add_exp_squeeze(x, p0) + R.output(gv1) + return gv1 + + @I.ir_module + class Expected: + I.module_attrs({"foo": "bar"}) + + @T.prim_func(private=True) + def fused_add_exp_squeeze( + x: T.Buffer((T.int64(10), T.int64(20)), "float32"), p0: T.Buffer((), "float32") + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + x[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] + for i0, i1 in T.grid(T.int64(10), T.int64(20)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + x[v_i0, v_i1] = T.exp(x[v_i0, v_i1]) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_squeeze"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + x[v_ax0, v_ax1] = x[v_ax0, v_ax1] + + # note that this will clobber x! Use with caution + @R.function + def main( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + cls = Expected + with R.dataflow(): + gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir_inplace( + cls.fused_add_exp_squeeze, + (x, p0), + out_sinfo=R.Tensor((10, 20), dtype="float32"), + inplace_indices=[0], + ) + R.output(gv1) + return gv1 + + _check(Module, Expected) + + +def test_fuse_inplace_and_non_inplace(): + @I.ir_module + class Module: + I.module_attrs({"foo": "bar"}) + + @T.prim_func(private=True) + def add( + A: T.Buffer((T.int64(10), T.int64(20)), "float32"), + B: T.Buffer((), "float32"), + Out: T.Buffer((T.int64(10), T.int64(20)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()] + + @T.prim_func(private=True) + def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + for i0, i1 in T.grid(T.int64(10), T.int64(20)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + A[v_i0, v_i1] = T.exp(A[v_i0, v_i1]) + + @T.prim_func(private=True) + def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_squeeze"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + + @R.function(private=True) + def fused_add_exp_squeeze( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + R.func_attr({"Primitive": 1}) + cls = Module + with R.dataflow(): + lv = R.call_tir( + cls.add, + (x, p0), + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + lv1 = R.call_tir_inplace( + cls.exp_inplace, + (lv,), + inplace_indices=[0], + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + gv = R.call_tir_inplace( + cls.squeeze_inplace, + (lv1,), + inplace_indices=[0], + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + R.output(gv) + return gv + + @R.function + def main( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + cls = Module + with R.dataflow(): + gv1: R.Tensor((10, 20), dtype="float32") = cls.fused_add_exp_squeeze(x, p0) + R.output(gv1) + return gv1 + + @I.ir_module + class Expected: + I.module_attrs({"foo": "bar"}) + + @T.prim_func(private=True) + def fused_add_exp_squeeze( + x: T.Buffer((T.int64(10), T.int64(20)), "float32"), + p0: T.Buffer((), "float32"), + p_output0: T.Buffer((T.int64(10), T.int64(20)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] + for i0, i1 in T.grid(T.int64(10), T.int64(20)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + p_output0[v_i0, v_i1] = T.exp(p_output0[v_i0, v_i1]) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_squeeze"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + p_output0[v_ax0, v_ax1] = p_output0[v_ax0, v_ax1] + + @R.function + def main( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + cls = Expected + with R.dataflow(): + gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir( + cls.fused_add_exp_squeeze, + (x, p0), + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + R.output(gv1) + return gv1 + + _check(Module, Expected) + + +def test_use_as_inplace_and_dps(): + @I.ir_module + class Module: + # we will use it both in-place and normally (DPS) + @T.prim_func(private=True) + def add( + A: T.Buffer((T.int64(10), T.int64(20)), "float32"), + B: T.Buffer((), "float32"), + Out: T.Buffer((T.int64(10), T.int64(20)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()] + + @R.function(private=True) + def fused_sums( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + R.func_attr({"Primitive": 1}) + cls = Module + with R.dataflow(): + lv = R.call_tir( + cls.add, + (x, p0), + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + lv1 = R.call_tir_inplace( + cls.add, + (x, p0, lv), + inplace_indices=[2], + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + lv2 = R.call_tir_inplace( + cls.add, + (x, p0, lv1), + inplace_indices=[2], + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + R.output(lv2) + return lv2 + + @R.function + def main( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + cls = Module + with R.dataflow(): + gv1: R.Tensor((10, 20), dtype="float32") = cls.fused_sums(x, p0) + R.output(gv1) + return gv1 + + @I.ir_module + class Expected: + @T.prim_func(private=True) + def fused_sums( + x: T.Buffer((T.int64(10), T.int64(20)), "float32"), + p0: T.Buffer((), "float32"), + p_output0: T.Buffer((T.int64(10), T.int64(20)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] + + @R.function + def main( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + cls = Expected + with R.dataflow(): + gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir( + cls.fused_sums, + (x, p0), + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + R.output(gv1) + return gv1 + + _check(Module, Expected) + + if __name__ == "__main__": tvm.testing.main()