From 2f936228aff6822ac03fd7f9849fc0512ce66284 Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Sun, 2 Jul 2023 16:06:49 -0700
Subject: [PATCH] [TIR][Transform] Add LiftThreadBinding Pass

This PR adds a pass LiftThreadBinding to TIR. During GPU cross-thread
reduction, a temporary local buffer is created to hold the rfactor (RF)
results, as in the following concrete example:

```python
rf_local = T.alloc_buffer(..., scope="local")
# Step 1. Data-parallel RF block
for tx in T.thread_binding(..., thread="threadIdx.x"):
    rf_local[tx, ...] = ...
# Step 2. Cross-thread reduction to accumulate rf_local
for ...:
    for tx' in T.thread_binding(..., thread="threadIdx.x"):
        ... += rf_local[tx', ...]
```

In this case, the buffer `rf_local` is only ever accessed at a single point,
`tx` or `tx'`. However, the pass `CompactBufferAllocation` treats the two
thread-binding variables as two separate variables, i.e. the information that
`tx` and `tx'` are always identical is discarded. The accessed region of
`rf_local` is therefore estimated as `Union({tx}, {tx'})` as opposed to
`{tx}`, leading to over-allocation of local registers. This pass addresses
the issue by lifting identical thread bindings to their LCA (lowest common
ancestor) loop.
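After lifting, the two bindings collapse into a single loop, so `rf_local` is
indexed by one variable only. A sketch of the resulting IR, in the same
elided style as above (the unit test below checks the full version):

```python
rf_local = T.alloc_buffer(..., scope="local")
for tx in T.thread_binding(..., thread="threadIdx.x"):
    # Step 1. Data-parallel RF block
    rf_local[tx, ...] = ...
    # Step 2. Cross-thread reduction now reuses the same tx
    for ...:
        ... += rf_local[tx, ...]
```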
---
 include/tvm/tir/transform.h                   |   6 +
 include/tvm/tir/var.h                         |   6 +
 python/tvm/tir/transform/transform.py         |  11 +
 src/driver/driver_api.cc                      |   1 +
 src/meta_schedule/postproc/verify_gpu_code.cc |   1 +
 src/tir/ir/expr.cc                            |   8 +-
 src/tir/transforms/lift_thread_binding.cc     | 195 ++++++++++++++++++
 .../test_tir_transform_lift_thread_binding.py | 139 +++++++++++++
 8 files changed, 365 insertions(+), 2 deletions(-)
 create mode 100644 src/tir/transforms/lift_thread_binding.cc
 create mode 100644 tests/python/unittest/test_tir_transform_lift_thread_binding.py

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index b315f709de1a..a1697d807db9 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -474,6 +474,12 @@ TVM_DLL Pass PlanAndUpdateBufferAllocationLocation();
  */
 TVM_DLL Pass ConvertBlocksToOpaque();
 
+/*!
+ * \brief Lift identical thread bindings to their LCA loops.
+ * \return The pass.
+ */
+TVM_DLL Pass LiftThreadBinding();
+
 /*!
  * \brief Compact the buffer access region by removing the buffer regions that are not accessed,
  * i.e. narrowing the buffer shape and adjust the access region if necessary.
diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h
index 52827f706a56..9cd2bed65739 100644
--- a/include/tvm/tir/var.h
+++ b/include/tvm/tir/var.h
@@ -103,6 +103,12 @@ class Var : public PrimExpr {
    * \param span The location of this object in the source code.
    */
   TVM_DLL explicit Var(String name_hint, Type type_annotation, Span span = Span());
+  /*!
+   * \brief Make a new copy of the var with the same type, but a different name.
+   * \param name The new name to be used.
+   * \return The new Var copy.
+   */
+  TVM_DLL Var copy_with_name(const String& name) const;
   /*!
    * \brief Make a new copy of var with same type, append suffix
    * \param suffix The suffix to be appended.
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 41740d9d8e86..0cd54064a7b5 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -831,6 +831,17 @@ def ConvertBlocksToOpaque():
     return _ffi_api.ConvertBlocksToOpaque()  # type: ignore
 
 
+def LiftThreadBinding():
+    """Lift identical thread bindings to their LCA loops.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.LiftThreadBinding()  # type: ignore
+
+
 def CompactBufferAllocation(is_strict: bool = True):
     """Compact the buffer access region. by removing the buffer regions
     that are not accessed, i.e. narrowing the buffer shape and adjust
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 0b766f1dd518..d46fab716814 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -202,6 +202,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
   pass_list.push_back(tir::transform::LowerInitBlock());
   pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
   pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+  pass_list.push_back(tir::transform::LiftThreadBinding());
   pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage());
   pass_list.push_back(tir::transform::CompactBufferAllocation());
   pass_list.push_back(tir::transform::LowerAutoCopy());
diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc
index 14e4b0f01e25..2fb97d32eb74 100644
--- a/src/meta_schedule/postproc/verify_gpu_code.cc
+++ b/src/meta_schedule/postproc/verify_gpu_code.cc
@@ -161,6 +161,7 @@ class VerifyGPUCodeNode : public PostprocNode {
       pass_list.push_back(tir::transform::LowerInitBlock());
       pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
       pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
+      pass_list.push_back(tir::transform::LiftThreadBinding());
       pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage());
       pass_list.push_back(tir::transform::CompactBufferAllocation());
       pass_list.push_back(tir::transform::Simplify());
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 9219dde2291b..d590f8b2dd8b 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -80,7 +80,7 @@ Var::Var(String name_hint, Type type_annotation, Span span) {
   data_ = std::move(n);
 }
 
-Var Var::copy_with_suffix(const String& suffix) const {
+Var Var::copy_with_name(const String& name) const {
   const VarNode* node = get();
   ObjectPtr<VarNode> new_ptr;
   if (auto* ptr = this->as<SizeVarNode>()) {
@@ -88,10 +88,14 @@
   } else {
     new_ptr = make_object<VarNode>(*node);
   }
-  new_ptr->name_hint = new_ptr->name_hint + suffix;
+  new_ptr->name_hint = name;
   return Var(new_ptr);
 }
 
+Var Var::copy_with_suffix(const String& suffix) const {
+  return this->copy_with_name(get()->name_hint + suffix);
+}
+
 Var Var::copy_with_dtype(DataType dtype) const {
   const VarNode* node = get();
   ObjectPtr<VarNode> new_ptr;
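To see where the new pass slots into lowering relative to
`CompactBufferAllocation`, here is a minimal standalone sketch (hypothetical;
`func` is assumed to be a scheduled `PrimFunc` with duplicated thread
bindings):

```python
import tvm
from tvm import tir

# Hypothetical pipeline slice mirroring the ordering added in driver_api.cc:
# LiftThreadBinding runs before CompactBufferAllocation so that compaction
# sees a single binding variable per thread tag.
seq = tvm.transform.Sequential(
    [
        tir.transform.ConvertBlocksToOpaque(),
        tir.transform.LiftThreadBinding(),
        tir.transform.CompactBufferAllocation(),
    ]
)
mod = seq(tvm.IRModule({"main": func}))
```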
diff --git a/src/tir/transforms/lift_thread_binding.cc b/src/tir/transforms/lift_thread_binding.cc
new file mode 100644
index 000000000000..9d7d455dbaed
--- /dev/null
+++ b/src/tir/transforms/lift_thread_binding.cc
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file lift_thread_binding.cc
+ * \brief Lift identical thread bindings to their LCA loops.
+ */
+
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <algorithm>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../../runtime/thread_storage_scope.h"
+#include "./ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+/*!
+ * \brief For each thread tag, find the LCA loop of all the loops bound to it.
+ * \return A pair of (1) a map from each LCA loop to the canonical IterVars and
+ * merged annotations of the thread tags lifted there, ordered innermost-first,
+ * and (2) the substitution map from every original loop var to the canonical
+ * var of its thread tag.
+ */
+std::pair<std::unordered_map<Stmt, std::vector<std::pair<IterVar, Map<String, ObjectRef>>>,
+                             ObjectPtrHash, ObjectPtrEqual>,
+          Map<Var, Var>>
+FindLoopLCA(const Stmt& root) {
+  class LCAFinder : public StmtVisitor {
+   public:
+    void VisitStmt_(const ForNode* op) final {
+      stack.push_back(GetRef<Stmt>(op));
+      StmtVisitor::VisitStmt_(op);
+      if (op->kind == ForKind::kThreadBinding) {
+        UpdateLCA(op);
+      }
+      stack.pop_back();
+    }
+
+    void UpdateLCA(const ForNode* loop) {
+      std::string thread_tag = loop->thread_binding.value()->thread_tag;
+      {
+        // Merge the annotations of every loop bound to this thread tag.
+        Map<String, ObjectRef>* tgt = &annotations[thread_tag];
+        for (const auto& kv : loop->annotations) {
+          tgt->Set(kv.first, kv.second);
+        }
+      }
+      IterVar& iter_var = iters[thread_tag];
+      if (!iter_var.defined()) {
+        // First loop seen for this thread tag: create the canonical IterVar
+        // and initialize the LCA path with the current loop stack.
+        iter_var = IterVar(Range::FromMinExtent(loop->min, loop->extent),  //
+                           loop->loop_var.copy_with_name(thread_tag),      //
+                           loop->thread_binding.value()->iter_type,        //
+                           thread_tag);
+        lca[thread_tag] = stack;
+        var_subst.Set(loop->loop_var, iter_var->var);
+        return;
+      }
+      var_subst.Set(loop->loop_var, iter_var->var);
+      // Trim the recorded path to its longest common prefix with the stack.
+      std::vector<Stmt>& path = lca[thread_tag];
+      uint32_t i = 0;
+      for (; i < stack.size() && i < path.size(); ++i) {
+        if (!stack[i].same_as(path[i])) {
+          break;
+        }
+      }
+      path.resize(i);
+    }
+
+    std::unordered_map<std::string, std::vector<Stmt>> lca;
+    std::unordered_map<std::string, IterVar> iters;
+    std::unordered_map<std::string, Map<String, ObjectRef>> annotations;
+    Map<Var, Var> var_subst;
+    std::vector<Stmt> stack;
+  };
+  LCAFinder finder;
+  finder(root);
+  std::unordered_map<Stmt, std::vector<std::pair<IterVar, Map<String, ObjectRef>>>, ObjectPtrHash,
+                     ObjectPtrEqual>
+      result;
+  std::vector<std::string> sorted_thread_tags;
+  for (const auto& kv : finder.lca) {
+    sorted_thread_tags.push_back(kv.first);
+  }
+  // The lifted loops are re-created inside-out, so order the tags from the
+  // highest thread rank (threadIdx) to the lowest (blockIdx), innermost first.
+  std::sort(sorted_thread_tags.begin(), sorted_thread_tags.end(),
+            [](const std::string& lhs, const std::string& rhs) {
+              runtime::ThreadScope lhs_scope = runtime::ThreadScope::Create(lhs);
+              runtime::ThreadScope rhs_scope = runtime::ThreadScope::Create(rhs);
+              if (lhs_scope.rank != rhs_scope.rank) {
+                return lhs_scope.rank > rhs_scope.rank;
+              }
+              return lhs_scope.dim_index < rhs_scope.dim_index;
+            });
+  for (const auto& thread_tag : sorted_thread_tags) {
+    Stmt lca = finder.lca[thread_tag].back();
+    const IterVar& iter = finder.iters[thread_tag];
+    const Map<String, ObjectRef>& annotations = finder.annotations[thread_tag];
+    result[lca].emplace_back(iter, annotations);
+  }
+  return {result, finder.var_subst};
+}
+
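To make the LCA computation concrete, a hand-written sketch (illustrative
TVMScript, not taken from this patch) of what `FindLoopLCA` plus the variable
substitution achieve for two sibling loops sharing a thread tag:

```python
# Before: two sibling loops bound to "threadIdx.x" inside one kernel;
# their LCA is the enclosing "blockIdx.x" loop.
for bx in T.thread_binding(4, thread="blockIdx.x"):
    for i in T.thread_binding(128, thread="threadIdx.x"):
        A[bx, i] = 0.0
    for j in T.thread_binding(128, thread="threadIdx.x"):
        B[bx, j] = A[bx, j]

# After: a single loop re-created at the LCA; both i and j are substituted
# by the canonical variable of the lifted loop.
for blockIdx_x in T.thread_binding(4, thread="blockIdx.x"):
    for threadIdx_x in T.thread_binding(128, thread="threadIdx.x"):
        A[blockIdx_x, threadIdx_x] = 0.0
        B[blockIdx_x, threadIdx_x] = A[blockIdx_x, threadIdx_x]
```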
+ */ +class ThreadBindingLifter : public StmtExprMutator { + public: + Stmt VisitStmt_(const ForNode* _op) final { + For op = GetRef(_op); + bool is_kernel_root = false; + if (op->kind == ForKind::kThreadBinding) { + if (iter_lca.empty()) { + is_kernel_root = true; + SetKernelRoot(_op); + } + } + For new_op = Downcast(StmtExprMutator::VisitStmt_(_op)); + Stmt body = std::move(new_op.CopyOnWrite()->body); + if (auto it = iter_lca.find(op); it != iter_lca.end()) { + for (const auto& [iter_var, annotation] : it->second) { + body = For(iter_var->var, iter_var->dom->min, iter_var->dom->extent, + ForKind::kThreadBinding, std::move(body), + IterVar(Range(nullptr), Var(iter_var->thread_tag, iter_var->var->dtype), + kThreadIndex, iter_var->thread_tag), + annotation); + } + } + if (is_kernel_root) { + iter_lca.clear(); + var_subst.clear(); + } + if (op->kind == ForKind::kThreadBinding) { + return body; + } else { + new_op.CopyOnWrite()->body = std::move(body); + return new_op; + } + } + + void SetKernelRoot(const ForNode* op) { + auto result = FindLoopLCA(GetRef(op)); + this->iter_lca = std::move(result.first); + this->var_subst = std::move(result.second); + } + + PrimExpr VisitExpr_(const VarNode* op) final { + auto it = var_subst.find(GetRef(op)); + if (it != var_subst.end()) { + return (*it).second; + } else { + return GetRef(op); + } + } + + std::unordered_map>>, ObjectPtrHash, + ObjectPtrEqual> + iter_lca; + Map var_subst; +}; + +PrimFunc LiftThreadBinding(PrimFunc f) { + // Only apply this pass to TIR that is not from TE schedules + if (!IsFromLegacyTESchedule(f)) { + PrimFuncNode* fptr = f.CopyOnWrite(); + fptr->body = ThreadBindingLifter()(std::move(fptr->body)); + return f; + } else { + return f; + } +} + +namespace transform { + +Pass LiftThreadBinding() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return LiftThreadBinding(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.LiftThreadBinding", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.LiftThreadBinding").set_body_typed(LiftThreadBinding); +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_transform_lift_thread_binding.py b/tests/python/unittest/test_tir_transform_lift_thread_binding.py new file mode 100644 index 000000000000..defcc6d6c1dc --- /dev/null +++ b/tests/python/unittest/test_tir_transform_lift_thread_binding.py @@ -0,0 +1,139 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/tests/python/unittest/test_tir_transform_lift_thread_binding.py b/tests/python/unittest/test_tir_transform_lift_thread_binding.py
new file mode 100644
index 000000000000..defcc6d6c1dc
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_lift_thread_binding.py
@@ -0,0 +1,139 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import tir
+from tvm.script import tir as T
+
+
+def test_lift_tx_beyond_local():
+    # fmt: off
+    @T.prim_func
+    def before(a: T.handle, b: T.handle, c: T.handle):
+        n = T.int32()
+        A = T.match_buffer(a, (32, 1, 128))
+        B = T.match_buffer(b, (32, n, 128))
+        C = T.match_buffer(c, (32, 1, n))
+        for ax0_ax1_fused in T.thread_binding(n * 32, thread="blockIdx.x"):
+            with T.block(""):
+                T.reads(A[ax0_ax1_fused // n, 0, 0:256], B[ax0_ax1_fused // n, ax0_ax1_fused % n, 0:256])
+                T.writes(C[ax0_ax1_fused // n, 0, ax0_ax1_fused % n])
+                D_local = T.alloc_buffer((32, 1, n), scope="local")
+                D_rf_local = T.alloc_buffer((256, 32, 1, n), scope="local")
+                for ax2_fused_1 in T.thread_binding(256, thread="threadIdx.x"):
+                    with T.block("NT_matmul_rf_init"):
+                        T.reads()
+                        T.writes(D_rf_local[ax2_fused_1, ax0_ax1_fused // n, 0, ax0_ax1_fused % n])
+                        D_rf_local[ax2_fused_1, ax0_ax1_fused // n, 0, ax0_ax1_fused % n] = T.float32(0)
+                    for ax2_fused_0 in range(1):
+                        with T.block("NT_matmul_rf_update"):
+                            T.where(ax2_fused_0 * 256 + ax2_fused_1 < 128)
+                            T.reads(D_rf_local[ax2_fused_1, ax0_ax1_fused // n, 0, ax0_ax1_fused % n], A[ax0_ax1_fused // n, 0, ax2_fused_0 * 256 + ax2_fused_1], B[ax0_ax1_fused // n, ax0_ax1_fused % n, ax2_fused_0 * 256 + ax2_fused_1])
+                            T.writes(D_rf_local[ax2_fused_1, ax0_ax1_fused // n, 0, ax0_ax1_fused % n])
+                            D_rf_local[ax2_fused_1, ax0_ax1_fused // n, 0, ax0_ax1_fused % n] = D_rf_local[ax2_fused_1, ax0_ax1_fused // n, 0, ax0_ax1_fused % n] + A[ax0_ax1_fused // n, 0, ax2_fused_0 * 256 + ax2_fused_1] * B[ax0_ax1_fused // n, ax0_ax1_fused % n, ax2_fused_0 * 256 + ax2_fused_1]
+                for ax1_ax2_fused in range(1):
+                    for ax0_fused in T.thread_binding(256, thread="threadIdx.x"):
+                        with T.block(""):
+                            T.reads(D_rf_local[ax0_fused, ax0_ax1_fused // n, 0, ax0_ax1_fused % n])
+                            T.writes(D_local[ax0_ax1_fused // n, 0, ax0_ax1_fused % n])
+                            cross_thread_D_local = T.alloc_buffer((1,), strides=(1,), scope="local")
+                            in_thread_D_local = T.alloc_buffer((1,), strides=(1,), scope="local")
+                            with T.block("NT_matmul_in_thread_init"):
+                                T.reads()
+                                T.writes(in_thread_D_local[0])
+                                in_thread_D_local[0] = T.float32(0)
+                            with T.block("NT_matmul_in_thread"):
+                                T.where(0 <= ax0_ax1_fused // n and ax0_ax1_fused // n < 32 and 0 <= ax0_ax1_fused % n and ax0_ax1_fused % n < n)
+                                T.reads(D_rf_local[ax0_fused, ax0_ax1_fused // n, 0, ax0_ax1_fused % n])
+                                T.writes(in_thread_D_local[0])
+                                in_thread_D_local[0] = in_thread_D_local[0] + D_rf_local[ax0_fused, ax0_ax1_fused // n, 0, ax0_ax1_fused % n]
+                            with T.block("NT_matmul_cross_thread"):
+                                T.reads(in_thread_D_local[0])
+                                T.writes(cross_thread_D_local[0])
+                                T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
+                                T.tvm_thread_allreduce(T.uint32(1), in_thread_D_local[0], T.bool(True), cross_thread_D_local[0], ax0_fused)
+                            with T.block("NT_matmul_write_back"):
+                                T.where(ax0_fused == 0)
+                                T.reads(cross_thread_D_local[0])
+                                T.writes(D_local[ax0_ax1_fused // n, 0, ax0_ax1_fused % n])
+                                D_local[ax0_ax1_fused // n, 0, ax0_ax1_fused % n] = cross_thread_D_local[0]
+                    with T.block("T_divide"):
+                        T.where(0 <= ax0_ax1_fused // n and ax0_ax1_fused // n < 32 and 0 <= ax0_ax1_fused % n and ax0_ax1_fused % n < n)
+                        T.reads(D_local[ax0_ax1_fused // n, 0, ax0_ax1_fused % n])
+                        T.writes(C[ax0_ax1_fused // n, 0, ax0_ax1_fused % n])
+                        C[ax0_ax1_fused // n, 0, ax0_ax1_fused % n] = D_local[ax0_ax1_fused // n, 0, ax0_ax1_fused % n] * T.float32(0.088397790055248615)
+
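+    # After LiftThreadBinding, the two "threadIdx.x" loops in `before` should
+    # be merged into a single loop lifted to sit directly under the
+    # "blockIdx.x" loop, as spelled out in `expected` below.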
+    @T.prim_func
+    def expected(A: T.Buffer((32, 1, 128), "float32"), b: T.handle, c: T.handle):
+        n = T.int32()
+        B = T.match_buffer(b, (32, n, 128))
+        C = T.match_buffer(c, (32, 1, n))
+        # with T.block("root"):
+        for blockIdx_x in T.thread_binding(n * 32, thread="blockIdx.x"):
+            for threadIdx_x in T.thread_binding(256, thread="threadIdx.x"):
+                with T.block(""):
+                    T.reads(A[blockIdx_x // n, 0, 0:256], B[blockIdx_x // n, blockIdx_x % n, 0:256])
+                    T.writes(C[blockIdx_x // n, 0, blockIdx_x % n])
+                    D_local = T.alloc_buffer((32, 1, n), scope="local")
+                    D_rf_local = T.alloc_buffer((256, 32, 1, n), scope="local")
+                    with T.block("NT_matmul_rf_init"):
+                        T.reads()
+                        T.writes(D_rf_local[threadIdx_x, blockIdx_x // n, 0, blockIdx_x % n])
+                        D_rf_local[threadIdx_x, blockIdx_x // n, 0, blockIdx_x % n] = T.float32(0)
+                    for ax2_fused_0 in range(1):
+                        with T.block("NT_matmul_rf_update"):
+                            T.where(ax2_fused_0 * 256 + threadIdx_x < 128)
+                            T.reads(D_rf_local[threadIdx_x, blockIdx_x // n, 0, blockIdx_x % n], A[blockIdx_x // n, 0, ax2_fused_0 * 256 + threadIdx_x], B[blockIdx_x // n, blockIdx_x % n, ax2_fused_0 * 256 + threadIdx_x])
+                            T.writes(D_rf_local[threadIdx_x, blockIdx_x // n, 0, blockIdx_x % n])
+                            D_rf_local[threadIdx_x, blockIdx_x // n, 0, blockIdx_x % n] = D_rf_local[threadIdx_x, blockIdx_x // n, 0, blockIdx_x % n] + A[blockIdx_x // n, 0, ax2_fused_0 * 256 + threadIdx_x] * B[blockIdx_x // n, blockIdx_x % n, ax2_fused_0 * 256 + threadIdx_x]
+                    for ax1_ax2_fused in range(1):
+                        with T.block(""):
+                            T.reads(D_rf_local[threadIdx_x, blockIdx_x // n, 0, blockIdx_x % n])
+                            T.writes(D_local[blockIdx_x // n, 0, blockIdx_x % n])
+                            cross_thread_D_local = T.alloc_buffer((1,), strides=(1,), scope="local")
+                            in_thread_D_local = T.alloc_buffer((1,), strides=(1,), scope="local")
+                            with T.block("NT_matmul_in_thread_init"):
+                                T.reads()
+                                T.writes(in_thread_D_local[0])
+                                in_thread_D_local[0] = T.float32(0)
+                            with T.block("NT_matmul_in_thread"):
+                                T.where(0 <= blockIdx_x // n and blockIdx_x // n < 32 and 0 <= blockIdx_x % n and blockIdx_x % n < n)
+                                T.reads(D_rf_local[threadIdx_x, blockIdx_x // n, 0, blockIdx_x % n])
+                                T.writes(in_thread_D_local[0])
+                                in_thread_D_local[0] = in_thread_D_local[0] + D_rf_local[threadIdx_x, blockIdx_x // n, 0, blockIdx_x % n]
+                            with T.block("NT_matmul_cross_thread"):
+                                T.reads(in_thread_D_local[0])
+                                T.writes(cross_thread_D_local[0])
+                                T.attr(T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), "reduce_scope", T.reinterpret("handle", T.uint64(0)))
+                                T.tvm_thread_allreduce(T.uint32(1), in_thread_D_local[0], T.bool(True), cross_thread_D_local[0], threadIdx_x)
+                            with T.block("NT_matmul_write_back"):
+                                T.where(threadIdx_x == 0)
+                                T.reads(cross_thread_D_local[0])
+                                T.writes(D_local[blockIdx_x // n, 0, blockIdx_x % n])
+                                D_local[blockIdx_x // n, 0, blockIdx_x % n] = cross_thread_D_local[0]
+                        with T.block("T_divide"):
+                            T.where(0 <= blockIdx_x // n and blockIdx_x // n < 32 and 0 <= blockIdx_x % n and blockIdx_x % n < n)
+                            T.reads(D_local[blockIdx_x // n, 0, blockIdx_x % n])
+                            T.writes(C[blockIdx_x // n, 0, blockIdx_x % n])
+                            C[blockIdx_x // n, 0, blockIdx_x % n] = D_local[blockIdx_x // n, 0, blockIdx_x % n] * T.float32(0.088397790055248615)
+    # fmt: on
+    mod = tvm.IRModule({"main": before})
+    after = tir.transform.LiftThreadBinding()(mod)
+    tvm.ir.assert_structural_equal(expected, after["main"])
+
+
+if __name__ == "__main__":
+    test_lift_tx_beyond_local()