From bfa28f55416b4424ad45bc15c07b7c7d278ab9a5 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 23 Jan 2024 23:24:10 -0500 Subject: [PATCH 1/8] WIP initial commit --- src/relax/transform/fuse_ops.cc | 10 +- src/relax/transform/fuse_tir.cc | 86 +++++++++--- tests/python/relax/test_transform_fuse_tir.py | 130 ++++++++++++++++++ 3 files changed, 205 insertions(+), 21 deletions(-) diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index b0eeba399e90..32780f6dd253 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -183,6 +183,8 @@ class GraphCreator : public ExprVisitor { ICHECK_NOTNULL(binding_var_node); static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace"); + OpPatternKind pattern = OpPatternKind::kOpaque; Array args = call->args; @@ -191,7 +193,7 @@ class GraphCreator : public ExprVisitor { // - Otherwise, the pattern of the current binding variable node is set to `kOpaque`, and we // recurse into the call expression. const auto* op = call->op.as(); - if (op == call_tir_op_.get()) { + if (op == call_tir_op_.get() || op == call_tir_inplace_op_.get()) { const GlobalVar& global_var = Downcast(call->args[0]); tir::PrimFunc func = Downcast(mod_->Lookup(global_var)); @@ -377,7 +379,8 @@ class FunctionCreator : public ExprMutator { * function accordingly * \param binding The binding to be appended * \note Allowed bindings are: - * - VarBinding with value being a call node calling `relax.call_tir`. + * - VarBinding with value being a call node calling `relax.call_tir` or + * `relax.call_tir_inplace`. * - VarBinding with value being a tuple-get-item node. * // TODO(tvm-team): handle match shape */ @@ -387,7 +390,8 @@ class FunctionCreator : public ExprMutator { if (const auto* var_binding = binding.as()) { if (const auto* call = var_binding->value.as()) { - if (call->op == Op::Get("relax.call_tir")) { + if (call->op == Op::Get("relax.call_tir") || + call->op == Op::Get("relax.call_tir_inplace")) { // Update the name of the function. name_hint_ = name_hint_ + "_" + Downcast(call->args[0])->name_hint; diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 1c25229d88f8..8c277a1d72ec 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include #include #include @@ -476,8 +477,11 @@ class FusedTIRConstructor : public ExprVisitor { void VisitExpr_(const CallNode* call) final { ExprVisitor::VisitExpr_(call); static const Op& call_tir_op_ = Op::Get("relax.call_tir"); - ICHECK(call->op == call_tir_op_) - << "Only call_tir is supported in primitive function, but got: " << GetRef(call); + static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace"); + + ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_) + << "Only call_tir and call_tir_inplace are supported in primitive function, but got: " + << GetRef(call); // Step 1. Get Global var and PrimFunc GlobalVar gv = Downcast(call->args[0]); @@ -503,7 +507,9 @@ class FusedTIRConstructor : public ExprVisitor { MapInputBuffer(prim_func, call->args[1]); const Array>& output_buffer_shapes = GetCallTIROutputShapes(call); - AllocateIntermediateBuffer(GetRef(call), prim_func, output_buffer_shapes); + // TODO(@tvm-team): We should be able to avoid intermediate allocations for in-place calls. 
+ // Currently the same logic is done for in-place calls to avoid crashes or garbage output. + AllocateIntermediateBuffer(call, prim_func, output_buffer_shapes); // Step 6. Update tir_vars if (call->args.size() > 2) { @@ -566,7 +572,8 @@ class FusedTIRConstructor : public ExprVisitor { */ static Array> GetCallTIROutputShapes(const CallNode* call) { static const Op& call_tir_op_ = Op::Get("relax.call_tir"); - ICHECK(call->op.same_as(call_tir_op_)); + static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace"); + ICHECK(call->op.same_as(call_tir_op_) || call->op.same_as(call_tir_inplace_op_)); ICHECK_EQ(call->sinfo_args.size(), 1); auto get_tensor_shape = [](const TensorStructInfoNode* sinfo) { const auto* shape_expr = sinfo->shape.as(); @@ -639,28 +646,56 @@ class FusedTIRConstructor : public ExprVisitor { MapArgsToBuffer(arg_list, buffer_list); } - static Array GetPrimFuncOutputParams(const tir::PrimFunc& func, size_t output_size) { + static Array GetInplaceOutputIndices(const Array& inplace_indices) { + Array ret; + int num_non_negative = 0; + for (auto idx : inplace_indices) { + int i = idx.IntValue(); + if (i > 0) { + num_non_negative++; + } + } + + int negative_idx_translation = num_non_negative; + for (auto idx : inplace_indices) { + int i = idx.IntValue(); + if (i > 0) { + ret.push_back(Integer(i)); + } else { + ret.push_back(Integer(negative_idx_translation)); + negative_idx_translation++; + } + } + + return ret; + } + + static Array GetPrimFuncOutputParams(const tir::PrimFunc& func, + const Array& output_indices) { size_t n = func->params.size(); int symbolic_var_index = -1; + size_t output_size = output_indices.size(); ICHECK_GE(n, output_size); - for (size_t i = 0; i < n; ++i) { - const tir::Var& param = func->params[i]; + + Array ret; + for (auto idx : output_indices) { + int i = idx.IntValue(); + const tir::Var& param = func->params[static_cast(i)]; if (param->dtype.is_int() || param->dtype.is_uint()) { if (symbolic_var_index == -1) symbolic_var_index = i; } else if (param->dtype.is_handle()) { CHECK(symbolic_var_index == -1) << "The scalar input should be at the ending of the " "parameter list."; + ret.push_back(param); } else { LOG(FATAL) << "The params of PrimFunc are expected to be Buffer handle or scalar, but got: " << param->dtype; } } + size_t end_index = symbolic_var_index == -1 ? n : symbolic_var_index; ICHECK_GE(end_index, output_size); - size_t begin_index = end_index - output_size; - Array output_params{func->params.begin() + begin_index, - func->params.begin() + end_index}; - return output_params; + return ret; } /*! @@ -670,14 +705,28 @@ class FusedTIRConstructor : public ExprVisitor { * \param func The old TIR PrimFunc * \param output_shapes The shape of output params. 
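+   * \param call The call node (`relax.call_tir` or `relax.call_tir_inplace`) whose output
+   *        buffers are being allocated; the op kind is used to detect in-place calls.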
*/ - void AllocateIntermediateBuffer(const Expr& expr, const tir::PrimFunc& func, + void AllocateIntermediateBuffer(const CallNode* call, const tir::PrimFunc& func, const Array>& output_shapes) { + bool is_inplace = (call->op == Op::Get("relax.call_tir_inplace")); + size_t n = func->params.size(); size_t output_size = output_shapes.size(); ICHECK_GE(n, output_size); // Allocate intermediate buffer Array alloc_buffers; - Array output_params = GetPrimFuncOutputParams(func, output_size); + Array output_idxs; + if (is_inplace) { + const auto* attrs = call->attrs.as(); + CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call"; + output_idxs = std::move(GetInplaceOutputIndices(attrs->inplace_indices)); + } else { + int num_inputs = Downcast(call->args[1])->fields.size(); + for (size_t i = 0; i < output_size; i++) { + output_idxs.push_back(num_inputs + i); + } + } + + Array output_params = GetPrimFuncOutputParams(func, output_idxs); for (size_t i = 0; i < output_size; ++i) { const tir::Var& param = output_params[i]; const tir::Buffer& buffer = func->buffer_map.at(param); @@ -710,7 +759,7 @@ class FusedTIRConstructor : public ExprVisitor { func_info_.buffer_subst_map.Set(buffer, new_buffer); } // Update expr2buffers - func_info_.expr2buffers.Set(expr, alloc_buffers); + func_info_.expr2buffers.Set(GetRef(call), alloc_buffers); } /*! @@ -945,6 +994,7 @@ class TIRFuseMutator : public ExprMutator { Expr VisitExpr_(const CallNode* op) final { static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace"); Call call = Downcast(builder_->Normalize(ExprMutator::VisitExpr_(op))); @@ -985,8 +1035,8 @@ class TIRFuseMutator : public ExprMutator { CHECK(prim_value->value.defined()) << "FuseTIR requires all R.Prim arguments to have a known value."; PrimExpr expr = prim_value->value.value(); - CHECK(expr->IsInstance()) - << "FuseTIR currently requires all R.Prim arguments to provide a single tir::Var."; + CHECK(expr->IsInstance()) << "FuseTIR currently requires all R.Prim " + "arguments to provide a single tir::Var."; tir_vars.push_back(expr); } else { @@ -1003,8 +1053,8 @@ class TIRFuseMutator : public ExprMutator { // Case 1.2. The callee function is not primitive, nothing to do. return call; } - } else if (call->op == call_tir_op_) { - // Case 2. It is a call_tir, re-emit the PrimFunc. + } else if (call->op == call_tir_op_ || call->op == call_tir_inplace_op_) { + // Case 2. It is a call_tir or call_tir_inplace, re-emit the PrimFunc. 
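+      // The rebuilt call is expected to keep the original op and attrs, so for
+      // call_tir_inplace the inplace_indices attrs carry over unchanged.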
if (const auto* gv = call->args[0].as()) { tir::PrimFunc func = Downcast(mod_->Lookup(GetRef(gv))); GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint); diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index 143670c70180..d0a6353079ab 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -1930,5 +1930,135 @@ def main( _check(Before, After) +def test_inplace_simple(): + @I.ir_module + class Module: + I.module_attrs({"foo": "bar"}) + + @T.prim_func(private=True) + def add_inplace( + A: T.Buffer((T.int64(10), T.int64(20)), "float32"), B: T.Buffer((), "float32") + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], B[()]) + T.writes(A[v_ax0, v_ax1]) + A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()] + + @T.prim_func(private=True) + def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + for i0, i1 in T.grid(T.int64(10), T.int64(20)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[v_i0, v_i1]) + T.writes(A[v_i0, v_i1]) + A[v_i0, v_i1] = T.exp(A[v_i0, v_i1]) + + @T.prim_func(private=True) + def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_squeeze"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1]) + T.writes(A[v_ax0, v_ax1]) + A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + + @R.function(private=True) + def fused_add_exp_squeeze( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + R.func_attr({"Primitive": 1}) + cls = Module + with R.dataflow(): + # this overwrites x and is actually evil but we are doing it just to test the pass + lv = R.call_tir_inplace( + cls.add_inplace, + (x, p0), + inplace_indices=[0], + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + lv1 = R.call_tir_inplace( + cls.exp_inplace, + (lv,), + inplace_indices=[0], + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + gv = R.call_tir_inplace( + cls.squeeze_inplace, + (lv1,), + inplace_indices=[0], + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + R.output(gv) + return gv + + @R.function + def main( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + cls = Module + with R.dataflow(): + gv1: R.Tensor((10, 20), dtype="float32") = cls.fused_add_exp_squeeze(x, p0) + R.output(gv1) + return gv1 + + @I.ir_module + class Expected: + I.module_attrs({"foo": "bar"}) + + @T.prim_func(private=True) + def fused_add_exp_squeeze( + x: T.Buffer((T.int64(10), T.int64(20)), "float32"), + p0: T.Buffer((), "float32"), + A_intermediate_1_2: T.Buffer((T.int64(10), T.int64(20)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + # TODO(@tvm-team): This is a temporary measure to avoid crashes when dealing with + # in-place calls. In reality, we should need intermediate allocations for + # in-place outputs. 
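+            # (the intent is that in-place outputs should not require fresh allocations;
+            # the A_intermediate buffers allocated below are the temporary workaround)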
+ A_intermediate = T.alloc_buffer((T.int64(10), T.int64(20))) + A_intermediate_1 = T.alloc_buffer((T.int64(10), T.int64(20))) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A_intermediate[v_ax0, v_ax1], p0[()]) + T.writes(A_intermediate[v_ax0, v_ax1]) + A_intermediate[v_ax0, v_ax1] = A_intermediate[v_ax0, v_ax1] + p0[()] + for i0, i1 in T.grid(T.int64(10), T.int64(20)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(A_intermediate_1[v_i0, v_i1]) + T.writes(A_intermediate_1[v_i0, v_i1]) + A_intermediate_1[v_i0, v_i1] = T.exp(A_intermediate_1[v_i0, v_i1]) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_squeeze"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A_intermediate_1_2[v_ax0, v_ax1]) + T.writes(A_intermediate_1_2[v_ax0, v_ax1]) + A_intermediate_1_2[v_ax0, v_ax1] = A_intermediate_1_2[v_ax0, v_ax1] + + @R.function + def main( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + cls = Module + with R.dataflow(): + gv1 = R.call_tir( + cls.fused_add_exp_squeeze, + (x, p0), + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + R.output(gv1) + return gv1 + + mod_after = relax.transform.FuseTIR()(Module) + print(mod_after) + assert False + + if __name__ == "__main__": tvm.testing.main() From cac58caba21a0999117af476bddbf33e09d6bd54 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 29 Jan 2024 22:07:27 -0500 Subject: [PATCH 2/8] Handle in-place calls in FuseTIR --- src/relax/transform/fuse_tir.cc | 101 +++++++---- tests/python/relax/test_transform_fuse_tir.py | 158 +++++++++++++++--- 2 files changed, 208 insertions(+), 51 deletions(-) diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 8c277a1d72ec..7d3ac53795f4 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -368,9 +368,10 @@ class FusedTIRConstructor : public ExprVisitor { * \brief Construct a fused TIR PrimFunc from a relax sub-function * \param mod The IRModule * \param gv The global var of relax subfunction to be fused into one PrimFunc - * \return The fused TIR PrimFunc + * \return The fused TIR PrimFunc and the in-place indices (non-empty for an in-place call) */ - static tir::PrimFunc GetFusedTIR(const IRModule& mod, const GlobalVar& gv) { + static std::pair> GetFusedTIR(const IRModule& mod, + const GlobalVar& gv) { FusedTIRConstructor visitor(mod, gv->name_hint); BaseFunc f = mod->Lookup(gv); CHECK(f->IsInstance()) @@ -378,7 +379,7 @@ class FusedTIRConstructor : public ExprVisitor { CHECK(f->HasNonzeroAttr(relax::attr::kPrimitive)) << "Expected a function with attr `kPrimitive`"; visitor(Downcast(f)); - return visitor.fused_tir_; + return {visitor.fused_tir_, visitor.inplace_indices_}; } private: @@ -439,9 +440,35 @@ class FusedTIRConstructor : public ExprVisitor { auto it = func_info_.expr2buffers.find(body); ICHECK(it != func_info_.expr2buffers.end()) << "Fail to detect output buffers for function body"; + const Array& buffers = (*it).second; + + // map of input buffers to indices (helpful for detecting in-place inputs) + std::unordered_map buffer_to_idx; + std::unordered_map input_to_idx; + for (size_t i = 0; i < func_info_.params.size(); i++) { + input_to_idx[func_info_.params[i]] = Integer(i); + } + for (auto kv : func_info_.buffer_map) { + if (input_to_idx.count(kv.first)) { + buffer_to_idx[kv.second] = 
input_to_idx[kv.first]; + } + } + + // numbered separately because the number of output *vars* might differ from the + // number of outputs if there are in-place inputs + int out_idx = 0; for (size_t i = 0; i < buffers.size(); ++i) { - tir::Var param = tir::Var("p_output" + std::to_string(i), PrimType(DataType::Handle())); + // Do not add output vars for in-place inputs + // (i.e., already listed in the buffer map. This would result + // in duplicates in the buffer map otherwise) + if (buffer_to_idx.count(buffers[i])) { + inplace_indices_.push_back(buffer_to_idx[buffers[i]]); + continue; + } + + tir::Var param = tir::Var("p_output" + std::to_string(out_idx), PrimType(DataType::Handle())); + out_idx++; func_info_.buffer_map.Set(param, buffers[i]); func_info_.params.push_back(param); func_info_.output_buffers.insert(buffers[i].get()); @@ -507,8 +534,6 @@ class FusedTIRConstructor : public ExprVisitor { MapInputBuffer(prim_func, call->args[1]); const Array>& output_buffer_shapes = GetCallTIROutputShapes(call); - // TODO(@tvm-team): We should be able to avoid intermediate allocations for in-place calls. - // Currently the same logic is done for in-place calls to avoid crashes or garbage output. AllocateIntermediateBuffer(call, prim_func, output_buffer_shapes); // Step 6. Update tir_vars @@ -618,7 +643,7 @@ class FusedTIRConstructor : public ExprVisitor { } } } - // Make sure every buffers are mapped. + // Make sure every buffer is mapped. ICHECK_EQ(buffer_idx, buffers.size()); } @@ -646,24 +671,17 @@ class FusedTIRConstructor : public ExprVisitor { MapArgsToBuffer(arg_list, buffer_list); } - static Array GetInplaceOutputIndices(const Array& inplace_indices) { + static Array GetInplaceOutputIndices(const Array& inplace_indices, + int num_inputs) { Array ret; - int num_non_negative = 0; - for (auto idx : inplace_indices) { - int i = idx.IntValue(); - if (i > 0) { - num_non_negative++; - } - } - - int negative_idx_translation = num_non_negative; + int last_idx = num_inputs; for (auto idx : inplace_indices) { int i = idx.IntValue(); - if (i > 0) { + if (i >= 0) { ret.push_back(Integer(i)); } else { - ret.push_back(Integer(negative_idx_translation)); - negative_idx_translation++; + ret.push_back(Integer(last_idx)); + last_idx++; } } @@ -710,27 +728,34 @@ class FusedTIRConstructor : public ExprVisitor { bool is_inplace = (call->op == Op::Get("relax.call_tir_inplace")); size_t n = func->params.size(); + int num_inputs = Downcast(call->args[1])->fields.size(); size_t output_size = output_shapes.size(); ICHECK_GE(n, output_size); - // Allocate intermediate buffer - Array alloc_buffers; + Array output_buffers; Array output_idxs; if (is_inplace) { const auto* attrs = call->attrs.as(); CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call"; - output_idxs = std::move(GetInplaceOutputIndices(attrs->inplace_indices)); + output_idxs = std::move(GetInplaceOutputIndices(attrs->inplace_indices, num_inputs)); } else { - int num_inputs = Downcast(call->args[1])->fields.size(); for (size_t i = 0; i < output_size; i++) { output_idxs.push_back(num_inputs + i); } } Array output_params = GetPrimFuncOutputParams(func, output_idxs); + auto input_buffers = func_info_.expr2buffers.Get(call->args[1]); for (size_t i = 0; i < output_size; ++i) { const tir::Var& param = output_params[i]; const tir::Buffer& buffer = func->buffer_map.at(param); + // if this is an inplace output, do not do an intermediate allocation + if (output_idxs[i].IntValue() < num_inputs) { + CHECK(input_buffers.defined()) << "Inplace 
functions must have some defined input"; + output_buffers.push_back(input_buffers.value()[output_idxs[i].IntValue()]); + continue; + } + auto unify_name_hints = [this, &buffer]() { String base_name = buffer->name; String unique_name = base_name + "_intermediate"; @@ -752,14 +777,14 @@ class FusedTIRConstructor : public ExprVisitor { n->name = unify_name_hints(); tir::Buffer new_buffer(n); func_info_.alloc_buffers.push_back(new_buffer); - alloc_buffers.push_back(new_buffer); + output_buffers.push_back(new_buffer); // Match the shape of the output buffer with the shape func_info_.symbolic_var_matcher.Match(buffer->shape, n->shape); func_info_.buffer_subst_map.Set(buffer, new_buffer); } // Update expr2buffers - func_info_.expr2buffers.Set(GetRef(call), alloc_buffers); + func_info_.expr2buffers.Set(GetRef(call), output_buffers); } /*! @@ -907,6 +932,8 @@ class FusedTIRConstructor : public ExprVisitor { FuseFuncInfo func_info_; /*! \brief The tir function after fusion*/ tir::PrimFunc fused_tir_; + /*! \brief Indices of inputs that are used for in-place computation */ + Array inplace_indices_; }; std::vector GetTupleAccessedIndices(const FunctionNode* func, const Var& tuple_var) { @@ -946,8 +973,11 @@ class TIRFuseMutator : public ExprMutator { for (const auto& [gv, func] : mod->functions) { // Only fuse primitive relax functions if (func->IsInstance() && func->HasNonzeroAttr(attr::kPrimitive)) { - tir::PrimFunc fused_tir = FusedTIRConstructor::GetFusedTIR(mod, gv); - mutator.fused_tir_funcs_.Set(gv, fused_tir); + const auto& [prim_func, indices] = FusedTIRConstructor::GetFusedTIR(mod, gv); + mutator.fused_tir_funcs_.Set(gv, prim_func); + if (!indices.empty()) { + mutator.inplace_indices_.Set(gv, indices); + } } } @@ -1043,12 +1073,20 @@ class TIRFuseMutator : public ExprMutator { arg_list.push_back(arg); } } - // Step b. Create call_tir + // Step b. Create call_tir or call_tir_inplace Array call_args = {fused_tir_gv, Tuple(arg_list)}; if (!tir_vars.empty()) { call_args.push_back(ShapeExpr(tir_vars)); } - return Call(call_tir_op_, call_args, call->attrs, {GetStructInfo(call)}); + Op call_op = call_tir_op_; + Attrs call_attrs = call->attrs; + if (inplace_indices_.count(old_gv)) { + call_op = call_tir_inplace_op_; + auto inplace_attrs = make_object(); + inplace_attrs->inplace_indices = inplace_indices_.at(old_gv); + call_attrs = Attrs(inplace_attrs); + } + return Call(call_op, call_args, call_attrs, {GetStructInfo(call)}); } else { // Case 1.2. The callee function is not primitive, nothing to do. return call; @@ -1073,6 +1111,9 @@ class TIRFuseMutator : public ExprMutator { const IRModule& mod_; /*! \brief The map from global var of primitive relax function to generated prim func. */ Map fused_tir_funcs_; + /*! \brief The map from global var of primitive relax function to in-place indices + * (if there are any). 
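+   * Each index is an argument position of the fused call that the fused PrimFunc writes
+   * in place; it is replayed as `inplace_indices` when the call is re-emitted as
+   * `call_tir_inplace`.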
*/ + Map> inplace_indices_; }; IRModule FuseTIR(IRModule mod) { diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index d0a6353079ab..7544260081ce 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -2006,6 +2006,129 @@ def main( R.output(gv1) return gv1 + @I.ir_module + class Expected: + I.module_attrs({"foo": "bar"}) + + @T.prim_func(private=True) + def fused_add_exp_squeeze( + x: T.Buffer((T.int64(10), T.int64(20)), "float32"), p0: T.Buffer((), "float32") + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(x[v_ax0, v_ax1], p0[()]) + T.writes(x[v_ax0, v_ax1]) + x[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] + for i0, i1 in T.grid(T.int64(10), T.int64(20)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(x[v_i0, v_i1]) + T.writes(x[v_i0, v_i1]) + x[v_i0, v_i1] = T.exp(x[v_i0, v_i1]) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_squeeze"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(x[v_ax0, v_ax1]) + T.writes(x[v_ax0, v_ax1]) + x[v_ax0, v_ax1] = x[v_ax0, v_ax1] + + # note that this will clobber x! Use with caution + @R.function + def main( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + cls = Expected + with R.dataflow(): + gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir_inplace( + cls.fused_add_exp_squeeze, + (x, p0), + out_sinfo=R.Tensor((10, 20), dtype="float32"), + inplace_indices=[0], + ) + R.output(gv1) + return gv1 + + _check(Module, Expected) + + +def test_fuse_inplace_and_non_inplace(): + @I.ir_module + class Module: + I.module_attrs({"foo": "bar"}) + + @T.prim_func(private=True) + def add( + A: T.Buffer((T.int64(10), T.int64(20)), "float32"), + B: T.Buffer((), "float32"), + Out: T.Buffer((T.int64(10), T.int64(20)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], B[()]) + T.writes(Out[v_ax0, v_ax1]) + Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()] + + @T.prim_func(private=True) + def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + for i0, i1 in T.grid(T.int64(10), T.int64(20)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[v_i0, v_i1]) + T.writes(A[v_i0, v_i1]) + A[v_i0, v_i1] = T.exp(A[v_i0, v_i1]) + + @T.prim_func(private=True) + def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_squeeze"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1]) + T.writes(A[v_ax0, v_ax1]) + A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + + @R.function(private=True) + def fused_add_exp_squeeze( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + R.func_attr({"Primitive": 1}) + cls = Module + with R.dataflow(): + lv = R.call_tir( + cls.add, + (x, p0), + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + lv1 = R.call_tir_inplace( + cls.exp_inplace, + (lv,), + inplace_indices=[0], + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + 
gv = R.call_tir_inplace( + cls.squeeze_inplace, + (lv1,), + inplace_indices=[0], + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + R.output(gv) + return gv + + @R.function + def main( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + cls = Module + with R.dataflow(): + gv1: R.Tensor((10, 20), dtype="float32") = cls.fused_add_exp_squeeze(x, p0) + R.output(gv1) + return gv1 + @I.ir_module class Expected: I.module_attrs({"foo": "bar"}) @@ -2014,50 +2137,43 @@ class Expected: def fused_add_exp_squeeze( x: T.Buffer((T.int64(10), T.int64(20)), "float32"), p0: T.Buffer((), "float32"), - A_intermediate_1_2: T.Buffer((T.int64(10), T.int64(20)), "float32"), + p_output0: T.Buffer((T.int64(10), T.int64(20)), "float32"), ): T.func_attr({"tir.noalias": T.bool(True)}) - # TODO(@tvm-team): This is a temporary measure to avoid crashes when dealing with - # in-place calls. In reality, we should need intermediate allocations for - # in-place outputs. - A_intermediate = T.alloc_buffer((T.int64(10), T.int64(20))) - A_intermediate_1 = T.alloc_buffer((T.int64(10), T.int64(20))) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): with T.block("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(A_intermediate[v_ax0, v_ax1], p0[()]) - T.writes(A_intermediate[v_ax0, v_ax1]) - A_intermediate[v_ax0, v_ax1] = A_intermediate[v_ax0, v_ax1] + p0[()] + T.reads(x[v_ax0, v_ax1], p0[()]) + T.writes(p_output0[v_ax0, v_ax1]) + p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] for i0, i1 in T.grid(T.int64(10), T.int64(20)): with T.block("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(A_intermediate_1[v_i0, v_i1]) - T.writes(A_intermediate_1[v_i0, v_i1]) - A_intermediate_1[v_i0, v_i1] = T.exp(A_intermediate_1[v_i0, v_i1]) + T.reads(p_output0[v_i0, v_i1]) + T.writes(p_output0[v_i0, v_i1]) + p_output0[v_i0, v_i1] = T.exp(p_output0[v_i0, v_i1]) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): with T.block("T_squeeze"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(A_intermediate_1_2[v_ax0, v_ax1]) - T.writes(A_intermediate_1_2[v_ax0, v_ax1]) - A_intermediate_1_2[v_ax0, v_ax1] = A_intermediate_1_2[v_ax0, v_ax1] + T.reads(p_output0[v_ax0, v_ax1]) + T.writes(p_output0[v_ax0, v_ax1]) + p_output0[v_ax0, v_ax1] = p_output0[v_ax0, v_ax1] @R.function def main( x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") ) -> R.Tensor((10, 20), dtype="float32"): - cls = Module + cls = Expected with R.dataflow(): - gv1 = R.call_tir( + gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir( cls.fused_add_exp_squeeze, (x, p0), - out_sinfo=R.Tensor((10, 20), dtype="float32"), + out_sinfo=R.Tensor((10, 20), dtype="float32") ) R.output(gv1) return gv1 - mod_after = relax.transform.FuseTIR()(Module) - print(mod_after) - assert False + _check(Module, Expected) if __name__ == "__main__": From 2067a86714d4b55d110698448079026eb51199e7 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 29 Jan 2024 22:15:55 -0500 Subject: [PATCH 3/8] Formatting --- tests/python/relax/test_transform_fuse_tir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index 7544260081ce..33419d95b7a5 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -2168,7 +2168,7 @@ def main( gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir( cls.fused_add_exp_squeeze, (x, p0), - 
out_sinfo=R.Tensor((10, 20), dtype="float32") + out_sinfo=R.Tensor((10, 20), dtype="float32"), ) R.output(gv1) return gv1 From 9fae00b9eb01a049e212902019ce7457d69a8fcb Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 29 Jan 2024 22:25:53 -0500 Subject: [PATCH 4/8] Add test case for FuseOps --- tests/python/relax/test_transform_fuse_ops.py | 141 ++++++++++++++++++ 1 file changed, 141 insertions(+) diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py index 1a4a630e3e5a..3cd608d8ee8f 100644 --- a/tests/python/relax/test_transform_fuse_ops.py +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -1501,5 +1501,146 @@ def main( _check(Module, Expected) +def test_call_tir_inplace(): + @I.ir_module + class Module: + @T.prim_func(private=True) + def add( + A: T.Buffer((T.int64(10), T.int64(20)), "float32"), + B: T.Buffer((), "float32"), + Out: T.Buffer((T.int64(10), T.int64(20)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], B[()]) + T.writes(Out[v_ax0, v_ax1]) + Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()] + + @T.prim_func(private=True) + def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + for i0, i1 in T.grid(T.int64(10), T.int64(20)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[v_i0, v_i1]) + T.writes(A[v_i0, v_i1]) + A[v_i0, v_i1] = T.exp(A[v_i0, v_i1]) + + @T.prim_func(private=True) + def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_squeeze"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1]) + T.writes(A[v_ax0, v_ax1]) + A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + + @R.function + def main( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + cls = Module + with R.dataflow(): + lv = R.call_tir( + cls.add, + (x, p0), + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + lv1 = R.call_tir_inplace( + cls.exp_inplace, + (lv,), + inplace_indices=[0], + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + gv = R.call_tir_inplace( + cls.squeeze_inplace, + (lv1,), + inplace_indices=[0], + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + R.output(gv) + return gv + + @I.ir_module + class Expected: + @T.prim_func(private=True) + def add( + A: T.Buffer((T.int64(10), T.int64(20)), "float32"), + B: T.Buffer((), "float32"), + Out: T.Buffer((T.int64(10), T.int64(20)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True), "op_pattern": 0}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1], B[()]) + T.writes(Out[v_ax0, v_ax1]) + Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()] + + @T.prim_func(private=True) + def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): + T.func_attr({"tir.noalias": T.bool(True), "op_pattern": 0}) + for i0, i1 in T.grid(T.int64(10), T.int64(20)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + T.reads(A[v_i0, v_i1]) + T.writes(A[v_i0, v_i1]) + A[v_i0, v_i1] = T.exp(A[v_i0, v_i1]) + + @T.prim_func(private=True) + def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): + 
T.func_attr({"tir.noalias": T.bool(True), "op_pattern": 0}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_squeeze"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[v_ax0, v_ax1]) + T.writes(A[v_ax0, v_ax1]) + A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + + @R.function(private=True) + def fused_add_exp_inplace_squeeze_inplace( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + R.func_attr({"Primitive": 1}) + cls = Expected + with R.dataflow(): + lv = R.call_tir( + cls.add, + (x, p0), + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + lv1 = R.call_tir_inplace( + cls.exp_inplace, + (lv,), + inplace_indices=[0], + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + gv = R.call_tir_inplace( + cls.squeeze_inplace, + (lv1,), + inplace_indices=[0], + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + R.output(gv) + return gv + + @R.function + def main( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + cls = Expected + with R.dataflow(): + gv1: R.Tensor( + (10, 20), dtype="float32" + ) = cls.fused_add_exp_inplace_squeeze_inplace(x, p0) + R.output(gv1) + return gv1 + + _check(Module, Expected) + + if __name__ == "__main__": tvm.testing.main() From ee1ed138af3744bedc4312996b811d816007ef24 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 30 Jan 2024 18:25:23 -0500 Subject: [PATCH 5/8] Address review comments related to clarity --- src/relax/transform/fuse_tir.cc | 14 ++++---- tests/python/relax/test_transform_fuse_tir.py | 36 ++++++------------- 2 files changed, 18 insertions(+), 32 deletions(-) diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 7d3ac53795f4..828e0406f7b4 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -449,9 +449,9 @@ class FusedTIRConstructor : public ExprVisitor { for (size_t i = 0; i < func_info_.params.size(); i++) { input_to_idx[func_info_.params[i]] = Integer(i); } - for (auto kv : func_info_.buffer_map) { - if (input_to_idx.count(kv.first)) { - buffer_to_idx[kv.second] = input_to_idx[kv.first]; + for (auto [var, buffer] : func_info_.buffer_map) { + if (auto it = input_to_idx.find(var); it != input_to_idx.end()) { + buffer_to_idx[buffer] = (*it).second; } } @@ -462,8 +462,8 @@ class FusedTIRConstructor : public ExprVisitor { // Do not add output vars for in-place inputs // (i.e., already listed in the buffer map. 
This would result // in duplicates in the buffer map otherwise) - if (buffer_to_idx.count(buffers[i])) { - inplace_indices_.push_back(buffer_to_idx[buffers[i]]); + if (auto it = buffer_to_idx.find(buffers[i]); it != buffer_to_idx.end()) { + inplace_indices_.push_back((*it).second); continue; } @@ -1080,10 +1080,10 @@ class TIRFuseMutator : public ExprMutator { } Op call_op = call_tir_op_; Attrs call_attrs = call->attrs; - if (inplace_indices_.count(old_gv)) { + if (auto it = inplace_indices_.find(old_gv); it != inplace_indices_.end()) { call_op = call_tir_inplace_op_; auto inplace_attrs = make_object(); - inplace_attrs->inplace_indices = inplace_indices_.at(old_gv); + inplace_attrs->inplace_indices = (*it).second; call_attrs = Attrs(inplace_attrs); } return Call(call_op, call_args, call_attrs, {GetStructInfo(call)}); diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index 33419d95b7a5..0113a4711cd1 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -1943,8 +1943,8 @@ def add_inplace( for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): with T.block("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(A[v_ax0, v_ax1], B[()]) - T.writes(A[v_ax0, v_ax1]) + # T.reads(A[v_ax0, v_ax1], B[()]) + # T.writes(A[v_ax0, v_ax1]) A[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()] @T.prim_func(private=True) @@ -1953,8 +1953,8 @@ def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): for i0, i1 in T.grid(T.int64(10), T.int64(20)): with T.block("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(A[v_i0, v_i1]) - T.writes(A[v_i0, v_i1]) + # T.reads(A[v_i0, v_i1]) + # T.writes(A[v_i0, v_i1]) A[v_i0, v_i1] = T.exp(A[v_i0, v_i1]) @T.prim_func(private=True) @@ -1963,8 +1963,8 @@ def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): with T.block("T_squeeze"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(A[v_ax0, v_ax1]) - T.writes(A[v_ax0, v_ax1]) + # T.reads(A[v_ax0, v_ax1]) + # T.writes(A[v_ax0, v_ax1]) A[v_ax0, v_ax1] = A[v_ax0, v_ax1] @R.function(private=True) @@ -1974,7 +1974,11 @@ def fused_add_exp_squeeze( R.func_attr({"Primitive": 1}) cls = Module with R.dataflow(): - # this overwrites x and is actually evil but we are doing it just to test the pass + # This overwrites x and is actually evil because the function is marked as pure + # but we are doing it just to test the pass. The automatic DataflowUseInplaceCalls + # transformation will not produce code like this, but it may make sense to do it + # if ownership of x is fully and truly transferred. + # Users should apply with caution! lv = R.call_tir_inplace( cls.add_inplace, (x, p0), @@ -2018,20 +2022,14 @@ def fused_add_exp_squeeze( for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): with T.block("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(x[v_ax0, v_ax1], p0[()]) - T.writes(x[v_ax0, v_ax1]) x[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] for i0, i1 in T.grid(T.int64(10), T.int64(20)): with T.block("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(x[v_i0, v_i1]) - T.writes(x[v_i0, v_i1]) x[v_i0, v_i1] = T.exp(x[v_i0, v_i1]) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): with T.block("T_squeeze"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(x[v_ax0, v_ax1]) - T.writes(x[v_ax0, v_ax1]) x[v_ax0, v_ax1] = x[v_ax0, v_ax1] # note that this will clobber x! 
Use with caution @@ -2068,8 +2066,6 @@ def add( for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): with T.block("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(A[v_ax0, v_ax1], B[()]) - T.writes(Out[v_ax0, v_ax1]) Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()] @T.prim_func(private=True) @@ -2078,8 +2074,6 @@ def exp_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): for i0, i1 in T.grid(T.int64(10), T.int64(20)): with T.block("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(A[v_i0, v_i1]) - T.writes(A[v_i0, v_i1]) A[v_i0, v_i1] = T.exp(A[v_i0, v_i1]) @T.prim_func(private=True) @@ -2088,8 +2082,6 @@ def squeeze_inplace(A: T.Buffer((T.int64(10), T.int64(20)), "float32")): for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): with T.block("T_squeeze"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(A[v_ax0, v_ax1]) - T.writes(A[v_ax0, v_ax1]) A[v_ax0, v_ax1] = A[v_ax0, v_ax1] @R.function(private=True) @@ -2143,20 +2135,14 @@ def fused_add_exp_squeeze( for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): with T.block("T_add"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(x[v_ax0, v_ax1], p0[()]) - T.writes(p_output0[v_ax0, v_ax1]) p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] for i0, i1 in T.grid(T.int64(10), T.int64(20)): with T.block("compute"): v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) - T.reads(p_output0[v_i0, v_i1]) - T.writes(p_output0[v_i0, v_i1]) p_output0[v_i0, v_i1] = T.exp(p_output0[v_i0, v_i1]) for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): with T.block("T_squeeze"): v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) - T.reads(p_output0[v_ax0, v_ax1]) - T.writes(p_output0[v_ax0, v_ax1]) p_output0[v_ax0, v_ax1] = p_output0[v_ax0, v_ax1] @R.function From 0a836ab8edadf53362d0d1abcae1178607a2b392 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 30 Jan 2024 19:04:27 -0500 Subject: [PATCH 6/8] Use a set to ensure in-place indices will be unique --- src/relax/transform/fuse_tir.cc | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 828e0406f7b4..bb8b530d1006 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -379,7 +379,11 @@ class FusedTIRConstructor : public ExprVisitor { CHECK(f->HasNonzeroAttr(relax::attr::kPrimitive)) << "Expected a function with attr `kPrimitive`"; visitor(Downcast(f)); - return {visitor.fused_tir_, visitor.inplace_indices_}; + Array inplace_indices; + for (size_t idx : visitor.inplace_indices_) { + inplace_indices.push_back(Integer(idx)); + } + return {visitor.fused_tir_, inplace_indices}; } private: @@ -444,10 +448,10 @@ class FusedTIRConstructor : public ExprVisitor { const Array& buffers = (*it).second; // map of input buffers to indices (helpful for detecting in-place inputs) - std::unordered_map buffer_to_idx; - std::unordered_map input_to_idx; + std::unordered_map buffer_to_idx; + std::unordered_map input_to_idx; for (size_t i = 0; i < func_info_.params.size(); i++) { - input_to_idx[func_info_.params[i]] = Integer(i); + input_to_idx[func_info_.params[i]] = i; } for (auto [var, buffer] : func_info_.buffer_map) { if (auto it = input_to_idx.find(var); it != input_to_idx.end()) { @@ -463,7 +467,7 @@ class FusedTIRConstructor : public ExprVisitor { // (i.e., already listed in the buffer map. 
This would result // in duplicates in the buffer map otherwise) if (auto it = buffer_to_idx.find(buffers[i]); it != buffer_to_idx.end()) { - inplace_indices_.push_back((*it).second); + inplace_indices_.insert((*it).second); continue; } @@ -933,7 +937,7 @@ class FusedTIRConstructor : public ExprVisitor { /*! \brief The tir function after fusion*/ tir::PrimFunc fused_tir_; /*! \brief Indices of inputs that are used for in-place computation */ - Array inplace_indices_; + std::unordered_set inplace_indices_; }; std::vector GetTupleAccessedIndices(const FunctionNode* func, const Var& tuple_var) { From 3b44ff1cc243ff9546c3cd7a0d0c9a4917b4d97b Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 30 Jan 2024 19:38:19 -0500 Subject: [PATCH 7/8] Add test case where PrimFunc is used both in-place and DPS --- tests/python/relax/test_transform_fuse_tir.py | 92 +++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index 0113a4711cd1..c0a6f4448b5c 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -2162,5 +2162,97 @@ def main( _check(Module, Expected) +def test_use_as_inplace_and_dps(): + @I.ir_module + class Module: + # we will use it both in-place and normally (DPS) + @T.prim_func(private=True) + def add( + A: T.Buffer((T.int64(10), T.int64(20)), "float32"), + B: T.Buffer((), "float32"), + Out: T.Buffer((T.int64(10), T.int64(20)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + Out[v_ax0, v_ax1] = A[v_ax0, v_ax1] + B[()] + + @R.function(private=True) + def fused_sums( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + R.func_attr({"Primitive": 1}) + cls = Module + with R.dataflow(): + lv = R.call_tir( + cls.add, + (x, p0), + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + lv1 = R.call_tir_inplace( + cls.add, + (x, p0, lv), + inplace_indices=[2], + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + lv2 = R.call_tir_inplace( + cls.add, + (x, p0, lv1), + inplace_indices=[2], + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + R.output(lv2) + return lv2 + + @R.function + def main( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + cls = Module + with R.dataflow(): + gv1: R.Tensor((10, 20), dtype="float32") = cls.fused_sums(x, p0) + R.output(gv1) + return gv1 + + @I.ir_module + class Expected: + @T.prim_func(private=True) + def fused_sums( + x: T.Buffer((T.int64(10), T.int64(20)), "float32"), + p0: T.Buffer((), "float32"), + p_output0: T.Buffer((T.int64(10), T.int64(20)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] + for ax0, ax1 in T.grid(T.int64(10), T.int64(20)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + p_output0[v_ax0, v_ax1] = x[v_ax0, v_ax1] + p0[()] + + @R.function + def main( + x: R.Tensor((10, 20), dtype="float32"), p0: R.Tensor((), 
dtype="float32") + ) -> R.Tensor((10, 20), dtype="float32"): + cls = Expected + with R.dataflow(): + gv1: R.Tensor((10, 20), dtype="float32") = R.call_tir( + cls.fused_sums, + (x, p0), + out_sinfo=R.Tensor((10, 20), dtype="float32"), + ) + R.output(gv1) + return gv1 + + _check(Module, Expected) + + if __name__ == "__main__": tvm.testing.main() From 111f08e183da0892dec47a7a2eae89f63e71ef3a Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Tue, 30 Jan 2024 19:47:37 -0500 Subject: [PATCH 8/8] Explicitly check for duplicate index --- src/relax/transform/fuse_tir.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index bb8b530d1006..4ad291e91cce 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -467,7 +467,10 @@ class FusedTIRConstructor : public ExprVisitor { // (i.e., already listed in the buffer map. This would result // in duplicates in the buffer map otherwise) if (auto it = buffer_to_idx.find(buffers[i]); it != buffer_to_idx.end()) { - inplace_indices_.insert((*it).second); + auto idx = (*it).second; + CHECK(!inplace_indices_.count(idx)) + << "In-place index " << idx << " used twice! An argument must be aliased."; + inplace_indices_.insert(idx); continue; }
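
A note on the index convention used by GetInplaceOutputIndices in fuse_tir.cc: each
non-negative entry of inplace_indices names the input parameter that the corresponding
output aliases, while a negative entry (-1) means the output gets its own parameter,
appended after the num_inputs input parameters. Below is a minimal Python sketch of that
mapping; the standalone function and its names are illustrative only, not part of the patch.

    def get_inplace_output_indices(inplace_indices, num_inputs):
        # Mirrors the C++ helper: non-negative entries alias an input parameter,
        # negative entries are assigned fresh output parameters after the inputs.
        result, next_fresh = [], num_inputs
        for idx in inplace_indices:
            if idx >= 0:
                result.append(idx)
            else:
                result.append(next_fresh)
                next_fresh += 1
        return result

    # For a PrimFunc with 3 inputs and inplace_indices=[-1, 2]:
    # output 0 gets a fresh parameter at position 3, output 1 reuses input parameter 2.
    assert get_inplace_output_indices([-1, 2], num_inputs=3) == [3, 2]

This is why outputs flagged as in-place in the tests above (e.g. x in test_inplace_simple)
end up reusing the caller's buffer instead of receiving a p_output parameter in the fused
PrimFunc.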