From 49e3e6e9542d01e5562c5ea67a6458a3ca0a13c0 Mon Sep 17 00:00:00 2001 From: adstraw Date: Mon, 10 Oct 2022 14:55:37 -0700 Subject: [PATCH 1/3] [Hexagon] Enable multi input Async DMA; same queue / stage --- .../transforms/inject_software_pipeline.cc | 31 +-- src/tir/transforms/lower_async_dma.cc | 7 +- .../test_software_pipeline_async.py | 176 +++++++++++++----- 3 files changed, 143 insertions(+), 71 deletions(-) diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index 08d57c53d1c2..bd2f8697d00e 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -762,11 +762,12 @@ class PipelineRewriter : public StmtExprMutator { << "Predicates in the same stage are expected to be identical"; group_bodies.push_back(new_blocks[i].block->body); } - auto body = group_bodies.size() > 1 ? SeqStmt(group_bodies) : group_bodies[0]; - auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)), - tir::attr::async_commit_queue_scope, stage_id, body); - auto new_block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_); - stmts.push_back(BlockRealize({}, predicate, new_block)); + for (auto body : group_bodies) { + auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)), + tir::attr::async_commit_queue_scope, stage_id, body); + auto new_block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_); + stmts.push_back(BlockRealize({}, predicate, new_block)); + } } } @@ -841,24 +842,8 @@ class PipelineRewriter : public StmtExprMutator { if (pipeline_info_[block].async) { auto& local_state = async_states_local[stage]; - int commit_group_id = -1; - if (local_state.commit_groups.empty() || local_state.consumed) { - // consumed == true means there is already a consumer stage waiting for an - // eariler async operation of this stage. In such cases, we make multiple commit_queue - // for this stage. 
- commit_group_id = local_state.commit_groups.size(); - local_state.commit_groups.push_back({new_blocks.size()}); - } else { - // This is the case when one commit_queue groups multiple async blocks. - // with commit_queue(stage): - // async_scope: - // A_shared[...] = ... - // async_scope: - // B_shared[...] = ... - - commit_group_id = local_state.commit_groups.size() - 1; - local_state.commit_groups.back().push_back(new_blocks.size()); - } + int commit_group_id = local_state.commit_groups.size(); + local_state.commit_groups.push_back({new_blocks.size()}); for (auto write_region : new_block->writes) { async_states[stage].dst_buffers.insert(write_region->buffer.get()); diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc index 78d363f67c02..417e9d61f263 100644 --- a/src/tir/transforms/lower_async_dma.cc +++ b/src/tir/transforms/lower_async_dma.cc @@ -94,9 +94,6 @@ class AsyncDMALowerer : public StmtExprMutator { ICHECK(queue_id_node); int queue_id = queue_id_node->value; - // save queue ID for inspection in `wait` transform - queue_ids.insert(queue_id); - // walk the graph to verify this is a mem copy ... 
// 1) async_commit_queue_scope contains async_scope auto async_scope = op->body.as(); @@ -161,6 +158,10 @@ class AsyncDMALowerer : public StmtExprMutator { return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap)); }); + // now that we are about to perform the `copy` transform + // save queue ID for inspection in `wait` transform + queue_ids.insert(queue_id); + return Evaluate(Call(DataType::Int(32), builtin::dma_copy(), {queue_id, Call(DataType::Handle(), builtin::address_of(), diff --git a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py index 943d4262f9da..0d6d3ff15602 100644 --- a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py +++ b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py @@ -30,81 +30,167 @@ inner = tvm.testing.parameter(64, 128) dtype = tvm.testing.parameter("uint8", "float16") scope = tvm.testing.parameter("global", "global.vtcm") +# TODO(Joseph) Turn on "multi_input_diffQ" compute type once we have upstreamed +# changes in the InjectSoftwarePipeline pass to alleviate this restriction: +# 'A dependency on multiple async stages is not supported' +comp_type = tvm.testing.parameter("single_input", "multi_input_sameQ") # TODO(Straw) Add back "cache_write" schedule type once we have upstreamed # buffer dependency analysis in InjectSoftwarePipeline pass # to insert approprite TIR "wait" attributes for this schedule -sched = tvm.testing.parameter("cache_read", "cache_read_write") +sched_type = tvm.testing.parameter("cache_read", "cache_read_write") @tvm.testing.fixture -def compute(outer, inner, dtype): - @T.prim_func - def plus_one_primfunc(A: T.Buffer[(outer, inner), dtype], B: T.Buffer[(outer, inner), dtype]): - for i in T.serial(outer): - for j in T.serial(inner): - with T.block("compute"): - with T.block(): - B[i, j] = A[i, j] + T.cast(1, dtype) +def data(comp_type, outer, inner, dtype): + out_np = 
np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype) + a_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype) + if comp_type == "single_input": + return out_np, a_np + else: + b_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype) + return out_np, a_np, b_np + + +@tvm.testing.fixture +def compute(comp_type, outer, inner, dtype): + if comp_type == "single_input": + + @T.prim_func + def a_plus_1_primfunc( + A: T.Buffer[(outer, inner), dtype], OUT: T.Buffer[(outer, inner), dtype] + ): + for i in T.serial(outer): + for j in T.serial(inner): + with T.block("compute"): + with T.block(): + OUT[i, j] = A[i, j] + T.cast(1, dtype) + + return a_plus_1_primfunc + else: + + @T.prim_func + def a_plus_b_plus_1_primfunc( + A: T.Buffer[(outer, inner), dtype], + B: T.Buffer[(outer, inner), dtype], + OUT: T.Buffer[(outer, inner), dtype], + ): + for i in T.serial(outer): + for j in T.serial(inner): + with T.block("compute"): + with T.block(): + OUT[i, j] = A[i, j] + B[i, j] + T.cast(1, dtype) + + return a_plus_b_plus_1_primfunc + + +@tvm.testing.fixture +def reference(comp_type): + if comp_type == "single_input": + + def a_plus_1_ref(a): + return a + 1 + + return a_plus_1_ref + else: - def plus_one_ref(a): - return a + 1 + def a_plus_b_plus_1_ref(a, b): + return a + b + 1 - return plus_one_primfunc, plus_one_ref + return a_plus_b_plus_1_ref @tvm.testing.fixture -def schedule(compute, sched, scope): - sch = tir.Schedule(compute[0]) +def schedule(comp_type, compute, sched_type, scope): + sch = tir.Schedule(compute) compute_block = sch.get_block("compute") i, _ = sch.get_loops(compute_block) - if sched == "cache_read": - cache_read_block = sch.cache_read(compute_block, 0, scope) - sch.compute_at(cache_read_block, i) - sch.annotate(i, "software_pipeline_stage", [0, 1]) - sch.annotate(i, "software_pipeline_order", [0, 1]) - sch.annotate(i, "software_pipeline_async_stages", [0]) - elif sched == "cache_write": - cache_write_block = 
sch.cache_write(compute_block, 0, scope) - sch.reverse_compute_at(cache_write_block, i) + if "read" in sched_type: + cache_read_a = sch.cache_read(compute_block, 0, scope) + sch.compute_at(cache_read_a, i) + + if "multi_input" in comp_type: + cache_read_b = sch.cache_read(compute_block, 1, scope) + sch.compute_at(cache_read_b, i) + + if "write" in sched_type: + cache_write_out = sch.cache_write(compute_block, 0, scope) + sch.reverse_compute_at(cache_write_out, i) + + if "read" in sched_type and "write" in sched_type: + if comp_type == "single_input": + sch.annotate(i, "software_pipeline_stage", [0, 1, 2]) + sch.annotate(i, "software_pipeline_order", [0, 1, 2]) + sch.annotate(i, "software_pipeline_async_stages", [0, 2]) + elif comp_type == "multi_input_sameQ": + sch.annotate(i, "software_pipeline_stage", [0, 0, 1, 2]) + sch.annotate(i, "software_pipeline_order", [0, 1, 2, 3]) + sch.annotate(i, "software_pipeline_async_stages", [0, 2]) + elif comp_type == "multi_input_diffQ": + sch.annotate(i, "software_pipeline_stage", [0, 1, 2, 3]) + sch.annotate(i, "software_pipeline_order", [0, 1, 2, 3]) + sch.annotate(i, "software_pipeline_async_stages", [0, 1, 2]) + + elif "read" in sched_type: + if comp_type == "single_input": + sch.annotate(i, "software_pipeline_stage", [0, 1]) + sch.annotate(i, "software_pipeline_order", [0, 1]) + sch.annotate(i, "software_pipeline_async_stages", [0]) + elif comp_type == "multi_input_sameQ": + sch.annotate(i, "software_pipeline_stage", [0, 0, 1]) + sch.annotate(i, "software_pipeline_order", [0, 1, 2]) + sch.annotate(i, "software_pipeline_async_stages", [0]) + elif comp_type == "multi_input_diffQ": + sch.annotate(i, "software_pipeline_stage", [0, 1, 2]) + sch.annotate(i, "software_pipeline_order", [0, 1, 2]) + sch.annotate(i, "software_pipeline_async_stages", [0, 1]) + + elif "write" in sched_type: sch.annotate(i, "software_pipeline_stage", [0, 1]) sch.annotate(i, "software_pipeline_order", [0, 1]) sch.annotate(i, 
"software_pipeline_async_stages", [1]) - elif sched == "cache_read_write": - cache_read_block = sch.cache_read(compute_block, 0, scope) - sch.compute_at(cache_read_block, i) - cache_write_block = sch.cache_write(compute_block, 0, scope) - sch.reverse_compute_at(cache_write_block, i) - sch.annotate(i, "software_pipeline_stage", [0, 1, 2]) - sch.annotate(i, "software_pipeline_order", [0, 1, 2]) - sch.annotate(i, "software_pipeline_async_stages", [0, 2]) return sch -@tvm.testing.requires_hexagon -def test_async_software_pipeline(hexagon_launcher, compute, schedule, outer, inner, dtype, scope): - sch = schedule +@tvm.testing.fixture +def verify(dtype): + def check(out, ref): + if "int" in dtype: + np.testing.assert_equal(out.numpy(), ref) + else: + np.testing.assert_allclose(out.numpy(), ref, rtol=1e-3, atol=1e-3) - a_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype) - b_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype) - ref = compute[1](a_np) + return check + + +@tvm.testing.requires_hexagon +def test_async_software_pipeline(hexagon_launcher, comp_type, data, reference, schedule, verify): + out_np = data[0] + a_np = data[1] + if comp_type == "single_input": + ref = reference(a_np) + else: + b_np = data[2] + ref = reference(a_np, b_np) with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): - func = tvm.build(sch.mod["main"], target=get_hexagon_target("v68")) + # tvm.lower(schedule.mod["main"]).show() + func = tvm.build(schedule.mod["main"], target=get_hexagon_target("v68")) with hexagon_launcher.start_session() as hexagon_session: dev = hexagon_session.device - a = tvm.nd.array(a_np, device=dev) - b = tvm.nd.array(b_np, device=dev) mod = hexagon_session.load_module(func) - mod(a, b) - - if "int" in dtype: - np.testing.assert_equal(b.numpy(), ref) + out = tvm.nd.array(out_np, device=dev) + a = tvm.nd.array(a_np, device=dev) + if comp_type == "single_input": + mod(a, out) else: - 
np.testing.assert_allclose(b.numpy(), ref, rtol=1e-3, atol=1e-3) + b = tvm.nd.array(b_np, device=dev) + mod(a, b, out) + + verify(out, ref) if __name__ == "__main__": From 0f07f03c92239801dc2088c5702e8ea6dcafa84d Mon Sep 17 00:00:00 2001 From: adstraw Date: Wed, 12 Oct 2022 09:27:19 -0700 Subject: [PATCH 2/3] add option to merge (or separate) async_commit_queue_scope attrs --- include/tvm/tir/transform.h | 2 +- src/driver/driver_api.cc | 5 +- src/meta_schedule/postproc/verify_gpu_code.cc | 3 +- .../transforms/inject_software_pipeline.cc | 58 ++++++++++++++----- .../test_software_pipeline_async.py | 4 +- 5 files changed, 54 insertions(+), 18 deletions(-) diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 6aa1aca69970..0bf7d768cfaa 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -629,7 +629,7 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner(); * * \return The IR transform pass. */ -TVM_DLL Pass InjectSoftwarePipeline(); +TVM_DLL Pass InjectSoftwarePipeline(bool merge_async_commit_queue_scope); TVM_DLL Pass BindParams(const Array& constants); diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index b0af0fb65e16..19c3f1844d15 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -51,6 +51,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array>); TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_async_commit_queue_scope", Bool); using runtime::PackedFunc; using runtime::TVMArgs; @@ -202,7 +203,9 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage()); pass_list.push_back(tir::transform::CompactBufferAllocation()); pass_list.push_back(tir::transform::LowerMatchBuffer()); - 
pass_list.push_back(tir::transform::InjectSoftwarePipeline()); + bool merge_async_commit_queue_scope = + pass_ctx->GetConfig("tir.merge_async_commit_queue_scope", Bool(true)).value(); + pass_list.push_back(tir::transform::InjectSoftwarePipeline(merge_async_commit_queue_scope)); pass_list.push_back(tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); pass_list.push_back(tir::transform::BF16Legalize()); diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index 0828ee538427..989acbafa047 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -163,7 +163,8 @@ class VerifyGPUCodeNode : public PostprocNode { pass_list.push_back(tir::transform::UnifyThreadBinding()); pass_list.push_back(tir::transform::CompactBufferAllocation()); pass_list.push_back(tir::transform::LowerMatchBuffer()); - pass_list.push_back(tir::transform::InjectSoftwarePipeline()); + pass_list.push_back( + tir::transform::InjectSoftwarePipeline(true /* merge_async_commit_queue_scope */)); pass_list.push_back(tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); pass_list.push_back(tir::transform::BF16Legalize()); diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index bd2f8697d00e..ac26b8dfab20 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -309,9 +309,10 @@ class PipelineRewriter : public StmtExprMutator { const Array pipeline_allocs, const For& pipeline_loop, const PipelineInfo& pipeline_info, const std::unordered_map& fragment_info, - const Map preserved_annotations) { + const Map preserved_annotations, bool merge_async_commit_queue_scope) { PipelineRewriter rewriter(buffer_data_to_buffer, double_buffers, pipeline_allocs, pipeline_loop, - pipeline_info, fragment_info, 
preserved_annotations); + pipeline_info, fragment_info, preserved_annotations, + merge_async_commit_queue_scope); return rewriter.BuildPipeline(); } @@ -321,7 +322,8 @@ class PipelineRewriter : public StmtExprMutator { const Array& pipeline_allocs, const For& pipeline_loop, const PipelineInfo& pipeline_info, const std::unordered_map& fragment_info, - const Map preserved_annotations) + const Map preserved_annotations, + bool merge_async_commit_queue_scope) : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)), double_buffers_(double_buffers), @@ -329,7 +331,8 @@ class PipelineRewriter : public StmtExprMutator { pipeline_loop_(pipeline_loop), pipeline_info_(pipeline_info), fragment_info_(fragment_info), - preserved_annotations_(preserved_annotations) {} + preserved_annotations_(preserved_annotations), + merge_async_commit_queue_scope_(merge_async_commit_queue_scope) {} Stmt BuildPipeline() { // Step 1: Analyze accesses to the buffers in the pipeline and compute the number of versions @@ -762,6 +765,13 @@ class PipelineRewriter : public StmtExprMutator { << "Predicates in the same stage are expected to be identical"; group_bodies.push_back(new_blocks[i].block->body); } + + if (merge_async_commit_queue_scope_ && group_bodies.size() > 1) { + auto merged_bodies = SeqStmt(group_bodies); + group_bodies.clear(); + group_bodies.push_back(merged_bodies); + } + for (auto body : group_bodies) { auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_commit_queue_scope, stage_id, body); @@ -842,8 +852,25 @@ class PipelineRewriter : public StmtExprMutator { if (pipeline_info_[block].async) { auto& local_state = async_states_local[stage]; - int commit_group_id = local_state.commit_groups.size(); - local_state.commit_groups.push_back({new_blocks.size()}); + int commit_group_id = -1; + if (local_state.commit_groups.empty() || local_state.consumed || + !merge_async_commit_queue_scope_) { + // consumed == true means there is already a consumer stage 
waiting for an + // earlier async operation of this stage. In such cases, we make multiple commit_queue + // for this stage. + commit_group_id = local_state.commit_groups.size(); + local_state.commit_groups.push_back({new_blocks.size()}); + } else { + // This is the case when one commit_queue groups multiple async blocks. + // with commit_queue(stage): + // async_scope: + // A_shared[...] = ... + // async_scope: + // B_shared[...] = ... + + commit_group_id = local_state.commit_groups.size() - 1; + local_state.commit_groups.back().push_back(new_blocks.size()); + } for (auto write_region : new_block->writes) { async_states[stage].dst_buffers.insert(write_region->buffer.get()); @@ -927,6 +954,7 @@ class PipelineRewriter : public StmtExprMutator { Array ordered_stmts_; std::map async_states; Map preserved_annotations_; + bool merge_async_commit_queue_scope_ = true; }; /*! @@ -965,8 +993,8 @@ void BuildDependencyGraph( class PipelineInjector : private StmtExprMutator { public: - static Stmt Inject(const PrimFunc& func) { - PipelineInjector injector; + static Stmt Inject(const PrimFunc& func, bool merge_async_commit_queue_scope) { + PipelineInjector injector(merge_async_commit_queue_scope); for (const auto& kv : func->buffer_map) { const Buffer& buffer = kv.second; injector.buffer_data_to_buffer_.Set(buffer->data, buffer); @@ -976,7 +1004,8 @@ class PipelineInjector : private StmtExprMutator { } private: - PipelineInjector() = default; + explicit PipelineInjector(bool merge_async_commit_queue_scope) + : merge_async_commit_queue_scope_(merge_async_commit_queue_scope) {} /*! * \brief Check the pipeline satisfies the following conditions: @@ -1111,9 +1140,9 @@ class PipelineInjector : private StmtExprMutator { ValidatePipelineBody(pipeline_info, original_order); // Step 4: Rewrite the pipeline body. 
- Stmt pipeline = PipelineRewriter::Rewrite(buffer_data_to_buffer_, double_buffers, - pipeline_allocs, GetRef(op), pipeline_info, - fragment_info_, preserved_annotations); + Stmt pipeline = PipelineRewriter::Rewrite( + buffer_data_to_buffer_, double_buffers, pipeline_allocs, GetRef(op), pipeline_info, + fragment_info_, preserved_annotations, merge_async_commit_queue_scope_); if (const auto* realize = op->body.as()) { const auto& block = realize->block; @@ -1182,6 +1211,7 @@ class PipelineInjector : private StmtExprMutator { Map buffer_data_to_buffer_; std::unordered_map fragment_info_; std::unordered_set double_buffers; + bool merge_async_commit_queue_scope_ = true; }; } // namespace software_pipeline @@ -1192,10 +1222,10 @@ namespace transform { * \brief Transform annotated loops into pipelined one that parallelize producers and consumers. * \return The IR transform pass. */ -Pass InjectSoftwarePipeline() { +Pass InjectSoftwarePipeline(bool merge_async_commit_queue_scope) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* fptr = f.CopyOnWrite(); - fptr->body = software_pipeline::PipelineInjector::Inject(f); + fptr->body = software_pipeline::PipelineInjector::Inject(f, merge_async_commit_queue_scope); fptr->body = ConvertSSA(std::move(fptr->body)); return f; }; diff --git a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py index 0d6d3ff15602..a883a9a251e3 100644 --- a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py +++ b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py @@ -175,7 +175,9 @@ def test_async_software_pipeline(hexagon_launcher, comp_type, data, reference, s b_np = data[2] ref = reference(a_np, b_np) - with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): + with tvm.transform.PassContext( + config={"tir.use_async_copy": 1, "tir.merge_async_commit_queue_scope": False} + ): # 
tvm.lower(schedule.mod["main"]).show() func = tvm.build(schedule.mod["main"], target=get_hexagon_target("v68")) From 24c01e74ab94a540afdce80649c243190c7b70ed Mon Sep 17 00:00:00 2001 From: adstraw Date: Wed, 12 Oct 2022 11:09:38 -0700 Subject: [PATCH 3/3] move merge_async_commit_queue_scope option select inside pass --- include/tvm/tir/transform.h | 2 +- src/driver/driver_api.cc | 4 +--- src/meta_schedule/postproc/verify_gpu_code.cc | 3 +-- src/tir/transforms/inject_software_pipeline.cc | 4 +++- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 0bf7d768cfaa..6aa1aca69970 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -629,7 +629,7 @@ TVM_DLL Pass UnifiedStaticMemoryPlanner(); * * \return The IR transform pass. */ -TVM_DLL Pass InjectSoftwarePipeline(bool merge_async_commit_queue_scope); +TVM_DLL Pass InjectSoftwarePipeline(); TVM_DLL Pass BindParams(const Array& constants); diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 19c3f1844d15..5f8c8742695d 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -203,9 +203,7 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage()); pass_list.push_back(tir::transform::CompactBufferAllocation()); pass_list.push_back(tir::transform::LowerMatchBuffer()); - bool merge_async_commit_queue_scope = - pass_ctx->GetConfig("tir.merge_async_commit_queue_scope", Bool(true)).value(); - pass_list.push_back(tir::transform::InjectSoftwarePipeline(merge_async_commit_queue_scope)); + pass_list.push_back(tir::transform::InjectSoftwarePipeline()); pass_list.push_back(tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); pass_list.push_back(tir::transform::BF16Legalize()); diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index 
989acbafa047..0828ee538427 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -163,8 +163,7 @@ class VerifyGPUCodeNode : public PostprocNode { pass_list.push_back(tir::transform::UnifyThreadBinding()); pass_list.push_back(tir::transform::CompactBufferAllocation()); pass_list.push_back(tir::transform::LowerMatchBuffer()); - pass_list.push_back( - tir::transform::InjectSoftwarePipeline(true /* merge_async_commit_queue_scope */)); + pass_list.push_back(tir::transform::InjectSoftwarePipeline()); pass_list.push_back(tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); pass_list.push_back(tir::transform::BF16Legalize()); diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index ac26b8dfab20..51523a37399b 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -1222,9 +1222,11 @@ namespace transform { * \brief Transform annotated loops into pipelined one that parallelize producers and consumers. * \return The IR transform pass. */ -Pass InjectSoftwarePipeline(bool merge_async_commit_queue_scope) { +Pass InjectSoftwarePipeline() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* fptr = f.CopyOnWrite(); + bool merge_async_commit_queue_scope = + ctx->GetConfig("tir.merge_async_commit_queue_scope", Bool(true)).value(); fptr->body = software_pipeline::PipelineInjector::Inject(f, merge_async_commit_queue_scope); fptr->body = ConvertSSA(std::move(fptr->body)); return f;