diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index b0af0fb65e16..5f8c8742695d 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -51,6 +51,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array>); TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_async_commit_queue_scope", Bool); using runtime::PackedFunc; using runtime::TVMArgs; diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index 08d57c53d1c2..51523a37399b 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -309,9 +309,10 @@ class PipelineRewriter : public StmtExprMutator { const Array pipeline_allocs, const For& pipeline_loop, const PipelineInfo& pipeline_info, const std::unordered_map& fragment_info, - const Map preserved_annotations) { + const Map preserved_annotations, bool merge_async_commit_queue_scope) { PipelineRewriter rewriter(buffer_data_to_buffer, double_buffers, pipeline_allocs, pipeline_loop, - pipeline_info, fragment_info, preserved_annotations); + pipeline_info, fragment_info, preserved_annotations, + merge_async_commit_queue_scope); return rewriter.BuildPipeline(); } @@ -321,7 +322,8 @@ class PipelineRewriter : public StmtExprMutator { const Array& pipeline_allocs, const For& pipeline_loop, const PipelineInfo& pipeline_info, const std::unordered_map& fragment_info, - const Map preserved_annotations) + const Map preserved_annotations, + bool merge_async_commit_queue_scope) : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)), double_buffers_(double_buffers), @@ -329,7 +331,8 @@ class PipelineRewriter : public StmtExprMutator { pipeline_loop_(pipeline_loop), pipeline_info_(pipeline_info), fragment_info_(fragment_info), - 
preserved_annotations_(preserved_annotations) {} + preserved_annotations_(preserved_annotations), + merge_async_commit_queue_scope_(merge_async_commit_queue_scope) {} Stmt BuildPipeline() { // Step 1: Analyze accesses to the buffers in the pipeline and compute the number of versions @@ -762,11 +765,19 @@ class PipelineRewriter : public StmtExprMutator { << "Predicates in the same stage are expected to be identical"; group_bodies.push_back(new_blocks[i].block->body); } - auto body = group_bodies.size() > 1 ? SeqStmt(group_bodies) : group_bodies[0]; - auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)), - tir::attr::async_commit_queue_scope, stage_id, body); - auto new_block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_); - stmts.push_back(BlockRealize({}, predicate, new_block)); + + if (merge_async_commit_queue_scope_ && group_bodies.size() > 1) { + auto merged_bodies = SeqStmt(group_bodies); + group_bodies.clear(); + group_bodies.push_back(merged_bodies); + } + + for (auto body : group_bodies) { + auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)), + tir::attr::async_commit_queue_scope, stage_id, body); + auto new_block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_); + stmts.push_back(BlockRealize({}, predicate, new_block)); + } } } @@ -842,7 +853,8 @@ class PipelineRewriter : public StmtExprMutator { auto& local_state = async_states_local[stage]; int commit_group_id = -1; - if (local_state.commit_groups.empty() || local_state.consumed) { + if (local_state.commit_groups.empty() || local_state.consumed || + !merge_async_commit_queue_scope_) { // consumed == true means there is already a consumer stage waiting for an // earlier async operation of this stage. In such cases, we make multiple commit_queue // for this stage. @@ -942,6 +954,7 @@ class PipelineRewriter : public StmtExprMutator { Array ordered_stmts_; std::map async_states; Map preserved_annotations_; + bool merge_async_commit_queue_scope_ = true; }; /*! 
@@ -980,8 +993,8 @@ void BuildDependencyGraph( class PipelineInjector : private StmtExprMutator { public: - static Stmt Inject(const PrimFunc& func) { - PipelineInjector injector; + static Stmt Inject(const PrimFunc& func, bool merge_async_commit_queue_scope) { + PipelineInjector injector(merge_async_commit_queue_scope); for (const auto& kv : func->buffer_map) { const Buffer& buffer = kv.second; injector.buffer_data_to_buffer_.Set(buffer->data, buffer); @@ -991,7 +1004,8 @@ class PipelineInjector : private StmtExprMutator { } private: - PipelineInjector() = default; + explicit PipelineInjector(bool merge_async_commit_queue_scope) + : merge_async_commit_queue_scope_(merge_async_commit_queue_scope) {} /*! * \brief Check the pipeline satisfies the following conditions: @@ -1126,9 +1140,9 @@ class PipelineInjector : private StmtExprMutator { ValidatePipelineBody(pipeline_info, original_order); // Step 4: Rewrite the pipeline body. - Stmt pipeline = PipelineRewriter::Rewrite(buffer_data_to_buffer_, double_buffers, - pipeline_allocs, GetRef(op), pipeline_info, - fragment_info_, preserved_annotations); + Stmt pipeline = PipelineRewriter::Rewrite( + buffer_data_to_buffer_, double_buffers, pipeline_allocs, GetRef(op), pipeline_info, + fragment_info_, preserved_annotations, merge_async_commit_queue_scope_); if (const auto* realize = op->body.as()) { const auto& block = realize->block; @@ -1197,6 +1211,7 @@ class PipelineInjector : private StmtExprMutator { Map buffer_data_to_buffer_; std::unordered_map fragment_info_; std::unordered_set double_buffers; + bool merge_async_commit_queue_scope_ = true; }; } // namespace software_pipeline @@ -1210,7 +1225,9 @@ namespace transform { Pass InjectSoftwarePipeline() { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* fptr = f.CopyOnWrite(); - fptr->body = software_pipeline::PipelineInjector::Inject(f); + bool merge_async_commit_queue_scope = + ctx->GetConfig("tir.merge_async_commit_queue_scope", 
Bool(true)).value(); + fptr->body = software_pipeline::PipelineInjector::Inject(f, merge_async_commit_queue_scope); fptr->body = ConvertSSA(std::move(fptr->body)); return f; }; diff --git a/src/tir/transforms/lower_async_dma.cc b/src/tir/transforms/lower_async_dma.cc index 78d363f67c02..417e9d61f263 100644 --- a/src/tir/transforms/lower_async_dma.cc +++ b/src/tir/transforms/lower_async_dma.cc @@ -94,9 +94,6 @@ class AsyncDMALowerer : public StmtExprMutator { ICHECK(queue_id_node); int queue_id = queue_id_node->value; - // save queue ID for inspection in `wait` transform - queue_ids.insert(queue_id); - // walk the graph to verify this is a mem copy ... // 1) async_commit_queue_scope contains async_scope auto async_scope = op->body.as(); @@ -161,6 +158,10 @@ class AsyncDMALowerer : public StmtExprMutator { return analyzer.Simplify(Substitute(std::move(expr), loop_var_remap)); }); + // now that we are about to perform the `copy` transform + // save queue ID for inspection in `wait` transform + queue_ids.insert(queue_id); + return Evaluate(Call(DataType::Int(32), builtin::dma_copy(), {queue_id, Call(DataType::Handle(), builtin::address_of(), diff --git a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py index 943d4262f9da..a883a9a251e3 100644 --- a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py +++ b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py @@ -30,81 +30,169 @@ inner = tvm.testing.parameter(64, 128) dtype = tvm.testing.parameter("uint8", "float16") scope = tvm.testing.parameter("global", "global.vtcm") +# TODO(Joseph) Turn on "multi_input_diffQ" compute type once we have upstreamed +# changes in the InjectSoftwarePipeline pass to alleviate this restriction: +# 'A dependency on multiple async stages is not supported' +comp_type = tvm.testing.parameter("single_input", "multi_input_sameQ") # TODO(Straw) Add back "cache_write" schedule type once 
we have upstreamed # buffer dependency analysis in InjectSoftwarePipeline pass # to insert appropriate TIR "wait" attributes for this schedule -sched = tvm.testing.parameter("cache_read", "cache_read_write") +sched_type = tvm.testing.parameter("cache_read", "cache_read_write") @tvm.testing.fixture -def compute(outer, inner, dtype): - @T.prim_func - def plus_one_primfunc(A: T.Buffer[(outer, inner), dtype], B: T.Buffer[(outer, inner), dtype]): - for i in T.serial(outer): - for j in T.serial(inner): - with T.block("compute"): - with T.block(): - B[i, j] = A[i, j] + T.cast(1, dtype) +def data(comp_type, outer, inner, dtype): + out_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype) + a_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype) + if comp_type == "single_input": + return out_np, a_np + else: + b_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype) + return out_np, a_np, b_np + + +@tvm.testing.fixture +def compute(comp_type, outer, inner, dtype): + if comp_type == "single_input": + + @T.prim_func + def a_plus_1_primfunc( + A: T.Buffer[(outer, inner), dtype], OUT: T.Buffer[(outer, inner), dtype] + ): + for i in T.serial(outer): + for j in T.serial(inner): + with T.block("compute"): + with T.block(): + OUT[i, j] = A[i, j] + T.cast(1, dtype) + + return a_plus_1_primfunc + else: + + @T.prim_func + def a_plus_b_plus_1_primfunc( + A: T.Buffer[(outer, inner), dtype], + B: T.Buffer[(outer, inner), dtype], + OUT: T.Buffer[(outer, inner), dtype], + ): + for i in T.serial(outer): + for j in T.serial(inner): + with T.block("compute"): + with T.block(): + OUT[i, j] = A[i, j] + B[i, j] + T.cast(1, dtype) + + return a_plus_b_plus_1_primfunc + + +@tvm.testing.fixture +def reference(comp_type): + if comp_type == "single_input": + + def a_plus_1_ref(a): + return a + 1 + + return a_plus_1_ref + else: - def plus_one_ref(a): - return a + 1 + def a_plus_b_plus_1_ref(a, b): + return a + b + 1 - return 
plus_one_primfunc, plus_one_ref + return a_plus_b_plus_1_ref @tvm.testing.fixture -def schedule(compute, sched, scope): - sch = tir.Schedule(compute[0]) +def schedule(comp_type, compute, sched_type, scope): + sch = tir.Schedule(compute) compute_block = sch.get_block("compute") i, _ = sch.get_loops(compute_block) - if sched == "cache_read": - cache_read_block = sch.cache_read(compute_block, 0, scope) - sch.compute_at(cache_read_block, i) - sch.annotate(i, "software_pipeline_stage", [0, 1]) - sch.annotate(i, "software_pipeline_order", [0, 1]) - sch.annotate(i, "software_pipeline_async_stages", [0]) - elif sched == "cache_write": - cache_write_block = sch.cache_write(compute_block, 0, scope) - sch.reverse_compute_at(cache_write_block, i) + if "read" in sched_type: + cache_read_a = sch.cache_read(compute_block, 0, scope) + sch.compute_at(cache_read_a, i) + + if "multi_input" in comp_type: + cache_read_b = sch.cache_read(compute_block, 1, scope) + sch.compute_at(cache_read_b, i) + + if "write" in sched_type: + cache_write_out = sch.cache_write(compute_block, 0, scope) + sch.reverse_compute_at(cache_write_out, i) + + if "read" in sched_type and "write" in sched_type: + if comp_type == "single_input": + sch.annotate(i, "software_pipeline_stage", [0, 1, 2]) + sch.annotate(i, "software_pipeline_order", [0, 1, 2]) + sch.annotate(i, "software_pipeline_async_stages", [0, 2]) + elif comp_type == "multi_input_sameQ": + sch.annotate(i, "software_pipeline_stage", [0, 0, 1, 2]) + sch.annotate(i, "software_pipeline_order", [0, 1, 2, 3]) + sch.annotate(i, "software_pipeline_async_stages", [0, 2]) + elif comp_type == "multi_input_diffQ": + sch.annotate(i, "software_pipeline_stage", [0, 1, 2, 3]) + sch.annotate(i, "software_pipeline_order", [0, 1, 2, 3]) + sch.annotate(i, "software_pipeline_async_stages", [0, 1, 2]) + + elif "read" in sched_type: + if comp_type == "single_input": + sch.annotate(i, "software_pipeline_stage", [0, 1]) + sch.annotate(i, "software_pipeline_order", [0, 1]) + 
sch.annotate(i, "software_pipeline_async_stages", [0]) + elif comp_type == "multi_input_sameQ": + sch.annotate(i, "software_pipeline_stage", [0, 0, 1]) + sch.annotate(i, "software_pipeline_order", [0, 1, 2]) + sch.annotate(i, "software_pipeline_async_stages", [0]) + elif comp_type == "multi_input_diffQ": + sch.annotate(i, "software_pipeline_stage", [0, 1, 2]) + sch.annotate(i, "software_pipeline_order", [0, 1, 2]) + sch.annotate(i, "software_pipeline_async_stages", [0, 1]) + + elif "write" in sched_type: sch.annotate(i, "software_pipeline_stage", [0, 1]) sch.annotate(i, "software_pipeline_order", [0, 1]) sch.annotate(i, "software_pipeline_async_stages", [1]) - elif sched == "cache_read_write": - cache_read_block = sch.cache_read(compute_block, 0, scope) - sch.compute_at(cache_read_block, i) - cache_write_block = sch.cache_write(compute_block, 0, scope) - sch.reverse_compute_at(cache_write_block, i) - sch.annotate(i, "software_pipeline_stage", [0, 1, 2]) - sch.annotate(i, "software_pipeline_order", [0, 1, 2]) - sch.annotate(i, "software_pipeline_async_stages", [0, 2]) return sch -@tvm.testing.requires_hexagon -def test_async_software_pipeline(hexagon_launcher, compute, schedule, outer, inner, dtype, scope): - sch = schedule +@tvm.testing.fixture +def verify(dtype): + def check(out, ref): + if "int" in dtype: + np.testing.assert_equal(out.numpy(), ref) + else: + np.testing.assert_allclose(out.numpy(), ref, rtol=1e-3, atol=1e-3) + + return check - a_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype) - b_np = np.random.uniform(low=0, high=128, size=(outer, inner)).astype(dtype) - ref = compute[1](a_np) - with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): - func = tvm.build(sch.mod["main"], target=get_hexagon_target("v68")) +@tvm.testing.requires_hexagon +def test_async_software_pipeline(hexagon_launcher, comp_type, data, reference, schedule, verify): + out_np = data[0] + a_np = data[1] + if comp_type == "single_input": + ref = 
reference(a_np) + else: + b_np = data[2] + ref = reference(a_np, b_np) + + with tvm.transform.PassContext( + config={"tir.use_async_copy": 1, "tir.merge_async_commit_queue_scope": False} + ): + # tvm.lower(schedule.mod["main"]).show() + func = tvm.build(schedule.mod["main"], target=get_hexagon_target("v68")) with hexagon_launcher.start_session() as hexagon_session: dev = hexagon_session.device - a = tvm.nd.array(a_np, device=dev) - b = tvm.nd.array(b_np, device=dev) mod = hexagon_session.load_module(func) - mod(a, b) - - if "int" in dtype: - np.testing.assert_equal(b.numpy(), ref) + out = tvm.nd.array(out_np, device=dev) + a = tvm.nd.array(a_np, device=dev) + if comp_type == "single_input": + mod(a, out) else: - np.testing.assert_allclose(b.numpy(), ref, rtol=1e-3, atol=1e-3) + b = tvm.nd.array(b_np, device=dev) + mod(a, b, out) + + verify(out, ref) if __name__ == "__main__":