From f180077eba6a261c74f1e66a8d810aa0e076b08d Mon Sep 17 00:00:00 2001
From: Ruihang Lai <ruihangl@cs.cmu.edu>
Date: Mon, 14 Apr 2025 21:25:59 -0400
Subject: [PATCH] [BugFix][TIR] Schedule support reverse-inline with reduction
 blocks

This PR fixes a bug in reverse-compute-inline of tir Schedule, which
generates incorrect TIR after inlining a transpose block into a reduction
block.
---
 src/tir/schedule/primitive/compute_inline.cc  | 28 ++++++++++-
 .../test_tir_schedule_compute_inline.py       | 49 +++++++++++++++++++
 2 files changed, 76 insertions(+), 1 deletion(-)

diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc
index df74497b4a69..85f3a0f82f76 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -586,6 +586,30 @@ class ReverseComputeInliner : public BaseInliner {
     ReverseComputeInliner* self_;
   };
 
+  class RecursionResolver : public StmtExprMutator {
+   public:
+    explicit RecursionResolver(ReverseComputeInliner* self) : self_(self) {}
+
+   private:
+    PrimExpr VisitExpr_(const VarNode* var) final {
+      auto it = self_->idx_sub_.find(var);
+      if (it == self_->idx_sub_.end()) {
+        return GetRef<PrimExpr>(var);
+      }
+      return (*it).second;
+    }
+
+    PrimExpr VisitExpr_(const BufferLoadNode* _load) final {
+      BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(_load));
+      return load->buffer.same_as(self_->inlined_buffer_)
+                 ? StmtExprMutator::VisitExpr(
+                       BufferLoad(self_->inlined_store_->buffer, self_->inlined_store_->indices))
+                 : load;
+    }
+
+    ReverseComputeInliner* self_;
+  };
+
  public:
   explicit ReverseComputeInliner(const Buffer& inlined_buffer, const BlockNode* producer_block,
                                  const BlockRealize& consumer_block_realize,
@@ -784,7 +808,9 @@ class ReverseComputeInliner : public BaseInliner {
   }
 
   Stmt ReplaceInlinedBuffer(BufferStore producer) {
-    producer_rhs_ = producer->value;
+    // "producer->value" may contain the buffer that is inlined in cases of reduction,
+    // so we need to resolve the recursion first
+    producer_rhs_ = RecursionResolver(this)(producer->value);
     return Substituter(this)(GetRef<Stmt>(inlined_store_));
   }
 
diff --git a/tests/python/tir-schedule/test_tir_schedule_compute_inline.py b/tests/python/tir-schedule/test_tir_schedule_compute_inline.py
index 2f779612a72a..066070f763ee 100644
--- a/tests/python/tir-schedule/test_tir_schedule_compute_inline.py
+++ b/tests/python/tir-schedule/test_tir_schedule_compute_inline.py
@@ -1529,5 +1529,54 @@ def after(
     assert_structural_equal_ignore_global_symbol(after, sch.mod["main"])
 
 
+def test_inline_with_reduction():
+    @T.prim_func
+    def before(
+        T_softmax_norm: T.Buffer((T.int64(6), T.int64(1), T.int64(1)), "float32"),
+        T_reshape_2: T.Buffer((T.int64(6), T.int64(1), T.int64(64)), "float32"),
+        T_transpose: T.Buffer((T.int64(1), T.int64(1), T.int64(6), T.int64(64)), "float32"),
+    ):
+        T_batch_matmul_NN = T.alloc_buffer((T.int64(6), T.int64(1), T.int64(64)))
+        for ax0, ax1 in T.grid(T.int64(6), T.int64(64)):
+            with T.block("bmm"):
+                v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(T_softmax_norm[v0, T.int64(0), T.int64(0)], T_reshape_2[v0, T.int64(0), v1])
+                T.writes(T_batch_matmul_NN[v0, T.int64(0), v1])
+                with T.init():
+                    T_batch_matmul_NN[v0, T.int64(0), v1] = T.float32(0.0)
+                T_batch_matmul_NN[v0, T.int64(0), v1] = (
+                    T_batch_matmul_NN[v0, T.int64(0), v1]
+                    + T_softmax_norm[v0, T.int64(0), T.int64(0)] * T_reshape_2[v0, T.int64(0), v1]
+                )
+        for ax0, ax1 in T.grid(T.int64(6), T.int64(64)):
+            with T.block("transpose"):
+                v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(T_batch_matmul_NN[v0, T.int64(0), v1])
+                T.writes(T_transpose[T.int64(0), T.int64(0), v0, v1])
+                T_transpose[T.int64(0), T.int64(0), v0, v1] = T_batch_matmul_NN[v0, T.int64(0), v1]
+
+    @T.prim_func
+    def after(
+        T_softmax_norm: T.Buffer((T.int64(6), T.int64(1), T.int64(1)), "float32"),
+        T_reshape_2: T.Buffer((T.int64(6), T.int64(1), T.int64(64)), "float32"),
+        T_transpose: T.Buffer((T.int64(1), T.int64(1), T.int64(6), T.int64(64)), "float32"),
+    ):
+        for ax0, ax1 in T.grid(T.int64(6), T.int64(64)):
+            with T.block("bmm"):
+                v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                T.reads(T_softmax_norm[v0, T.int64(0), T.int64(0)], T_reshape_2[v0, T.int64(0), v1])
+                T.writes(T_transpose[T.int64(0), T.int64(0), v0, v1])
+                with T.init():
+                    T_transpose[T.int64(0), T.int64(0), v0, v1] = T.float32(0.0)
+                T_transpose[T.int64(0), T.int64(0), v0, v1] = (
+                    T_transpose[T.int64(0), T.int64(0), v0, v1]
+                    + T_softmax_norm[v0, T.int64(0), T.int64(0)] * T_reshape_2[v0, T.int64(0), v1]
+                )
+
+    sch = tir.Schedule(before)
+    sch.reverse_compute_inline(sch.get_block("transpose"))
+    assert_structural_equal_ignore_global_symbol(after, sch.mod["main"])
+
+
 if __name__ == "__main__":
     tvm.testing.main()