Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions src/tir/transforms/inject_software_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,10 @@ class PipelineRewriter : public StmtExprMutator {
const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>& double_buffers,
const Array<Buffer> pipeline_allocs, const For& pipeline_loop,
const PipelineInfo& pipeline_info,
const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info) {
const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info,
const Map<String, ObjectRef> preserved_annotations) {
PipelineRewriter rewriter(buffer_data_to_buffer, double_buffers, pipeline_allocs, pipeline_loop,
pipeline_info, fragment_info);
pipeline_info, fragment_info, preserved_annotations);
return rewriter.BuildPipeline();
}

Expand All @@ -319,14 +320,16 @@ class PipelineRewriter : public StmtExprMutator {
const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>& double_buffers,
const Array<Buffer>& pipeline_allocs, const For& pipeline_loop,
const PipelineInfo& pipeline_info,
const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info)
const std::unordered_map<const VarNode*, FragmentInfo>& fragment_info,
const Map<String, ObjectRef> preserved_annotations)

: buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
double_buffers_(double_buffers),
pipeline_allocs_(pipeline_allocs),
pipeline_loop_(pipeline_loop),
pipeline_info_(pipeline_info),
fragment_info_(fragment_info) {}
fragment_info_(fragment_info),
preserved_annotations_(preserved_annotations) {}

Stmt BuildPipeline() {
// Step 1: Analyze accesses to the buffers in the pipeline and compute the number of versions
Expand Down Expand Up @@ -903,7 +906,8 @@ class PipelineRewriter : public StmtExprMutator {

if (!is_unit_loop) {
new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, std::move(new_loop));
unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, std::move(new_loop),
NullOpt, preserved_annotations_);
}

// Update producer heads in the global async states.
Expand Down Expand Up @@ -937,6 +941,7 @@ class PipelineRewriter : public StmtExprMutator {
Map<Buffer, Buffer> buffer_remap_;
Array<Block> ordered_stmts_;
std::map<int, AsyncStateGlobal> async_states;
Map<String, ObjectRef> preserved_annotations_;
};

/*!
Expand Down Expand Up @@ -1100,6 +1105,15 @@ class PipelineInjector : private StmtExprMutator {
}
}

// Collect every annotation attached to the loop `op` except the three
// software-pipeline attributes (stage, order, async stages). The filtered map
// is what gets passed to PipelineRewriter::Rewrite as `preserved_annotations`.
Map<String, ObjectRef> preserved_annotations;
for (const auto& kv : op->annotations) {
const String& key = kv.first;
// Keep the entry only when its key is none of the pipeline-specific attributes.
if (kv.first != attr::software_pipeline_stage && kv.first != attr::software_pipeline_order &&
kv.first != attr::software_pipeline_async_stages) {
preserved_annotations.Set(key, kv.second);
}
}

for (size_t i = 0; i < pipeline_stages.size(); i++) {
int stage = static_cast<int>(pipeline_stages[i]->value);
bool is_async = pipeline_async_stages.find(stage) != pipeline_async_stages.end();
Expand All @@ -1112,9 +1126,9 @@ class PipelineInjector : private StmtExprMutator {
ValidatePipelineBody(pipeline_info, original_order);

// Step 4: Rewrite the pipeline body.
Stmt pipeline =
PipelineRewriter::Rewrite(buffer_data_to_buffer_, double_buffers, pipeline_allocs,
GetRef<For>(op), pipeline_info, fragment_info_);
Stmt pipeline = PipelineRewriter::Rewrite(buffer_data_to_buffer_, double_buffers,
pipeline_allocs, GetRef<For>(op), pipeline_info,
fragment_info_, preserved_annotations);

if (const auto* realize = op->body.as<BlockRealizeNode>()) {
const auto& block = realize->block;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,69 @@ def transformed_simple_compute(
C[tx, 15] = B[1, tx, 0] + T.float32(1)


# Test fixture: a two-stage computation C[tx, i] = A[tx, i] * 2 + 1 whose inner
# loop carries software-pipeline annotations together with one additional,
# unrelated annotation ("pragma_loop_partition_hint").
# Used as the first argument to `_check` in
# test_simple_compute_with_other_annotation.
@T.prim_func
def simple_compute_with_other_annotation(
A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]
):
for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
for i in T.serial(
0,
16,
annotations={
# Pipeline attributes: one entry per inner block below.
"software_pipeline_stage": [0, 1],
"software_pipeline_order": [0, 1],
# Extra annotation that is not one of the software_pipeline_* keys.
"pragma_loop_partition_hint": True,
},
):
with T.block("compute"):
T.reads(A[tx, i])
T.writes(C[tx, i])
# Intermediate buffer in "shared" scope, written by the first inner
# block and read by the second.
B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
# First inner block: B[tx, 0] = A[tx, i] * 2
with T.block():
T.reads(A[tx, i])
T.writes(B[tx, 0])
B[tx, 0] = A[tx, i] * T.float32(2)
# Second inner block: C[tx, i] = B[tx, 0] + 1
with T.block():
T.reads(B[tx, 0])
T.writes(C[tx, i])
C[tx, i] = B[tx, 0] + T.float32(1)


# Test fixture: the counterpart of simple_compute_with_other_annotation that
# test_simple_compute_with_other_annotation passes to `_check` as the second
# argument. The 16-iteration loop is replaced by a block before the loop, a
# 15-iteration loop, and a block after the loop; the loop keeps only the
# "pragma_loop_partition_hint" annotation.
@T.prim_func
def transformed_simple_compute_with_other_annotation(
A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]
) -> None:
for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
with T.block():
T.reads([A[tx, 0:16]])
T.writes([C[tx, 0:16]])
# B has a leading dimension of size 2 here; it is indexed with 0, 1,
# (i + 1) % 2 and i % 2 below.
B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
# Before the loop: compute B[0, tx, 0] from column 0 of A.
with T.block():
T.reads([A[tx, 0]])
T.writes([B[0, tx, 0]])
B[0, tx, 0] = A[tx, 0] * T.float32(2)
# Main loop: iteration i fills B[(i + 1) % 2] from A[tx, i + 1] and
# writes C[tx, i] from B[i % 2].
with T.block():
T.reads([A[tx, 1:16], B[0:2, tx, 0]])
T.writes([B[0:2, tx, 0], C[tx, 0:15]])
for i in T.serial(
0,
15,
# Only the non-pipeline annotation is present on this loop.
annotations={"pragma_loop_partition_hint": True},
):
with T.block():
T.reads([A[tx, i + 1]])
T.writes([B[(i + 1) % 2, tx, 0]])
B[(i + 1) % 2, tx, 0] = A[tx, i + 1] * T.float32(2)
with T.block():
T.reads([B[i % 2, tx, 0]])
T.writes([C[tx, i]])
C[tx, i] = B[i % 2, tx, 0] + T.float32(1)
# After the loop: write the last column C[tx, 15] from B[1, tx, 0].
with T.block():
T.reads([B[1, tx, 0]])
T.writes([C[tx, 15]])
C[tx, 15] = B[1, tx, 0] + T.float32(1)


@T.prim_func
def three_stage_compute(A: T.Buffer[(16, 16), "float32"], D: T.Buffer[(16, 16), "float32"]):
for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
Expand Down Expand Up @@ -1000,6 +1063,10 @@ def test_simple_compute():
_check(gen_simple_compute(1), transformed_simple_compute)


def test_simple_compute_with_other_annotation():
    """Run `_check` on the extra-annotation fixture and its `transformed_` counterpart."""
    source_func = simple_compute_with_other_annotation
    expected_func = transformed_simple_compute_with_other_annotation
    _check(source_func, expected_func)


def test_trivial_pipeline():
    """Run `_check` on `trivial_pipeline` and `transformed_trivial_pipeline`."""
    source_func = trivial_pipeline
    expected_func = transformed_trivial_pipeline
    _check(source_func, expected_func)

Expand Down