From e200ad6a0d88ff1840bd4fcca161519ca053ac64 Mon Sep 17 00:00:00 2001 From: qsqqsqqsq-intellif <2628869@qq.com> Date: Thu, 29 Sep 2022 03:57:51 +0000 Subject: [PATCH] [TIR] Preserve loop annotations in inject_software_pipeline pass --- .../transforms/inject_software_pipeline.cc | 30 ++++++--- ..._tir_transform_inject_software_pipeline.py | 67 +++++++++++++++++++ 2 files changed, 89 insertions(+), 8 deletions(-) diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index 2d97aa1a1158..08d57c53d1c2 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -308,9 +308,10 @@ class PipelineRewriter : public StmtExprMutator { const std::unordered_set& double_buffers, const Array pipeline_allocs, const For& pipeline_loop, const PipelineInfo& pipeline_info, - const std::unordered_map& fragment_info) { + const std::unordered_map& fragment_info, + const Map preserved_annotations) { PipelineRewriter rewriter(buffer_data_to_buffer, double_buffers, pipeline_allocs, pipeline_loop, - pipeline_info, fragment_info); + pipeline_info, fragment_info, preserved_annotations); return rewriter.BuildPipeline(); } @@ -319,14 +320,16 @@ class PipelineRewriter : public StmtExprMutator { const std::unordered_set& double_buffers, const Array& pipeline_allocs, const For& pipeline_loop, const PipelineInfo& pipeline_info, - const std::unordered_map& fragment_info) + const std::unordered_map& fragment_info, + const Map preserved_annotations) : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)), double_buffers_(double_buffers), pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop), pipeline_info_(pipeline_info), - fragment_info_(fragment_info) {} + fragment_info_(fragment_info), + preserved_annotations_(preserved_annotations) {} Stmt BuildPipeline() { // Step 1: Analyze accesses to the buffers in the pipeline and compute the number of versions @@ -903,7 +906,8 @@ class PipelineRewriter : public StmtExprMutator { if (!is_unit_loop) { new_loop = For(Downcast(new_loop_var), pipeline_loop_->min, extent, - unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, std::move(new_loop)); + unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, std::move(new_loop), + NullOpt, preserved_annotations_); } // Update producer heads in the global async states. @@ -937,6 +941,7 @@ class PipelineRewriter : public StmtExprMutator { Map buffer_remap_; Array ordered_stmts_; std::map async_states; + Map preserved_annotations_; }; /*! @@ -1100,6 +1105,15 @@ class PipelineInjector : private StmtExprMutator { } } + Map preserved_annotations; + for (const auto& kv : op->annotations) { + const String& key = kv.first; + if (kv.first != attr::software_pipeline_stage && kv.first != attr::software_pipeline_order && + kv.first != attr::software_pipeline_async_stages) { + preserved_annotations.Set(key, kv.second); + } + } + for (size_t i = 0; i < pipeline_stages.size(); i++) { int stage = static_cast(pipeline_stages[i]->value); bool is_async = pipeline_async_stages.find(stage) != pipeline_async_stages.end(); @@ -1112,9 +1126,9 @@ class PipelineInjector : private StmtExprMutator { ValidatePipelineBody(pipeline_info, original_order); // Step 4: Rewrite the pipeline body. - Stmt pipeline = - PipelineRewriter::Rewrite(buffer_data_to_buffer_, double_buffers, pipeline_allocs, - GetRef(op), pipeline_info, fragment_info_); + Stmt pipeline = PipelineRewriter::Rewrite(buffer_data_to_buffer_, double_buffers, + pipeline_allocs, GetRef(op), pipeline_info, + fragment_info_, preserved_annotations); if (const auto* realize = op->body.as()) { const auto& block = realize->block; diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py index 49255e0f2094..9334a4d9e827 100644 --- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py +++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py @@ -151,6 +151,69 @@ def transformed_simple_compute( C[tx, 15] = B[1, tx, 0] + T.float32(1) +@T.prim_func +def simple_compute_with_other_annotation( + A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"] +): + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + for i in T.serial( + 0, + 16, + annotations={ + "software_pipeline_stage": [0, 1], + "software_pipeline_order": [0, 1], + "pragma_loop_partition_hint": True, + }, + ): + with T.block("compute"): + T.reads(A[tx, i]) + T.writes(C[tx, i]) + B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") + with T.block(): + T.reads(A[tx, i]) + T.writes(B[tx, 0]) + B[tx, 0] = A[tx, i] * T.float32(2) + with T.block(): + T.reads(B[tx, 0]) + T.writes(C[tx, i]) + C[tx, i] = B[tx, 0] + T.float32(1) + + +@T.prim_func +def transformed_simple_compute_with_other_annotation( + A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"] +) -> None: + for tx in T.thread_binding(0, 16, thread="threadIdx.x"): + with T.block(): + T.reads([A[tx, 0:16]]) + T.writes([C[tx, 0:16]]) + B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared") + with T.block(): + T.reads([A[tx, 0]]) + T.writes([B[0, tx, 0]]) + B[0, tx, 0] = A[tx, 0] * T.float32(2) + with T.block(): + T.reads([A[tx, 1:16], B[0:2, tx, 0]]) + T.writes([B[0:2, tx, 0], C[tx, 0:15]]) + for i in T.serial( + 0, + 15, + annotations={"pragma_loop_partition_hint": True}, + ): + with T.block(): + T.reads([A[tx, i + 1]]) + T.writes([B[(i + 1) % 2, tx, 0]]) + B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2) + with T.block(): + T.reads([B[i % 2, tx, 0]]) + T.writes([C[tx, i]]) + C[tx, i] = B[i % 2, tx, 0] + T.float32(1) + with T.block(): + T.reads([B[1, tx, 0]]) + T.writes([C[tx, 15]]) + C[tx, 15] = B[1, tx, 0] + T.float32(1) + + @T.prim_func def three_stage_compute(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]): for tx in T.thread_binding(0, 16, thread="threadIdx.x"): @@ -1000,6 +1063,10 @@ def test_simple_compute(): _check(gen_simple_compute(1), transformed_simple_compute) +def test_simple_compute_with_other_annotation(): + _check(simple_compute_with_other_annotation, transformed_simple_compute_with_other_annotation) + + def test_trivial_pipeline(): _check(trivial_pipeline, transformed_trivial_pipeline)