From 652068838506d2c088b97d3bfcee65656086eee9 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 24 Jun 2020 09:52:07 -0700 Subject: [PATCH 1/5] Add LegalizeInvalidAttach --- .gitignore | 1 + src/te/schedule/schedule_dataflow_rewrite.cc | 76 +++++++++++++++++++- tests/python/unittest/test_te_schedule.py | 20 ++++++ 3 files changed, 95 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index b9357018a64c..506e54d93067 100644 --- a/.gitignore +++ b/.gitignore @@ -196,6 +196,7 @@ tvm_t.* .python_history .pytest_cache .local +cmake-build-debug # Visual Studio Code .vscode diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index af72d3b1a1df..acd594df9663 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -451,7 +451,7 @@ Tensor Schedule::cache_write(const Tensor& tensor, const std::string& scope) { } } -void RebaseNonZeroMinLoop(const Schedule& sch) { +void RebaseNonZeroMinLoop(ScheduleNode* sch) { std::unordered_map rebase_map; for (Stage s : sch->stages) { if (s->attach_type == kInlinedAlready) continue; @@ -614,10 +614,82 @@ void InjectInline(ScheduleNode* sch) { } } +void LegalizeInvalidAttach(ScheduleNode* sch) { + // Legalize the compute_at location if the target iterator of compute_at was split or fused. + // Case 1: If the target of compute_at is split, + // we will move the compute_at location to the inner iterator. + // Case 2: If the target of compute_at is fused, + // we will move the compute_at location to the newly fused iterator. + // Note that case 2 can only happen if the target of compute_at + // is the innermost operands of fuse operation. + + std::unordered_map replace_map; + + for (Stage stage : sch->stages) { + for (Stage s = stage; s.defined();) { + Stage spec = s.GetAttachSpec(); + if (spec->attach_type != kScope) { + break; + } + bool start_attach = false; + IterVar attach_ivar = spec->attach_ivar; + s = spec->attach_stage; + CHECK(attach_ivar.defined()); + CHECK(s.defined()); + + for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) { + IterVar iv = s->leaf_iter_vars[i - 1]; + if (!start_attach && iv.same_as(attach_ivar)) { + start_attach = true; + } + } + + if (!start_attach) { + IterVar new_attach_ivar = attach_ivar;; + bool updated = true; + // recursively update the relations + while (updated) { + updated = false; + for (const auto& rel : s->relations) { + if (const FuseNode* r = rel.as()) { + if (new_attach_ivar.same_as(r->inner)) { + new_attach_ivar = r->fused; + updated = true; + } + } else if (const SplitNode* r = rel.as()) { + if (new_attach_ivar.same_as(r->parent)) { + new_attach_ivar = r->inner; + updated = true; + } + } + } + replace_map[attach_ivar] = new_attach_ivar; + } + } + } + } + + // remap the parent relation + for (Stage s : sch->stages) { + if (s->attach_type != kScope) continue; + if (replace_map.count(s->attach_ivar)) { + s->attach_ivar = replace_map.at(s->attach_ivar); + } + } + for (Stage s : sch->groups) { + if (s->attach_type != kScope) continue; + if (replace_map.count(s->attach_ivar)) { + s->attach_ivar = replace_map.at(s->attach_ivar); + } + } +} + + Schedule Schedule::normalize() { Schedule sn = copy(); InjectInline(sn.operator->()); - RebaseNonZeroMinLoop(sn); + RebaseNonZeroMinLoop(sn.operator->()); + LegalizeInvalidAttach(sn.operator->()); return sn; } diff --git a/tests/python/unittest/test_te_schedule.py b/tests/python/unittest/test_te_schedule.py index 2c851cc39789..c00ee70586ef 100644 --- a/tests/python/unittest/test_te_schedule.py +++ b/tests/python/unittest/test_te_schedule.py @@ -289,6 +289,25 @@ def intrin_func(ins, outs, sp): assert str(stmt.body.body.value.args[3]) == "(i: int32*i)" assert str(stmt.body.body.value.args[4]) == "(i: int32 + j: int32)" +def test_legalize_invalid_attach(): + A = te.compute((10, 10), lambda i, j: 1.0, name='A') + B = te.compute((10, 10), lambda i, j: A[i][j], name='B') + + # Case 1: Split an axis which is the target of a compute_at + s = te.create_schedule([B.op]) + s[A].compute_at(s[B], B.op.axis[1]) + s[B].split(B.op.axis[1], 2) + + stmt = tvm.lower(s, [A, B], simple_mode=True)['main'].body + assert isinstance(stmt.body.body, tvm.tir.stmt.For) + + # Case 2: Fuse an axis which is the target of a compute_at + s = te.create_schedule([B.op]) + s[A].compute_at(s[B], B.op.axis[1]) + s[B].fuse(B.op.axis[0], B.op.axis[1]) + stmt = tvm.lower(s, [A, B], simple_mode=True)['main'].body + assert isinstance(stmt, tvm.tir.stmt.For) + if __name__ == "__main__": test_singleton() test_pragma() @@ -305,3 +324,4 @@ def intrin_func(ins, outs, sp): test_fuse_with_out_of_order_axis_with_reorder() test_vectorize() test_vectorize_commreduce() + test_legalize_invalid_attach() From 116d8a6a04fa437bd232a5921380eae20875c2e6 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 24 Jun 2020 09:59:06 -0700 Subject: [PATCH 2/5] lint & typo --- src/te/schedule/schedule_dataflow_rewrite.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index acd594df9663..1dda4fafcf2c 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -615,13 +615,13 @@ void InjectInline(ScheduleNode* sch) { } void LegalizeInvalidAttach(ScheduleNode* sch) { - // Legalize the compute_at location if the target iterator of compute_at was split or fused. + // Legalize the compute_at location if the target iterator of compute_at is split or fused. // Case 1: If the target of compute_at is split, // we will move the compute_at location to the inner iterator. // Case 2: If the target of compute_at is fused, // we will move the compute_at location to the newly fused iterator. // Note that case 2 can only happen if the target of compute_at - // is the innermost operands of fuse operation. + // is the innermost operand of fuse operation. std::unordered_map replace_map; @@ -645,7 +645,7 @@ void LegalizeInvalidAttach(ScheduleNode* sch) { } if (!start_attach) { - IterVar new_attach_ivar = attach_ivar;; + IterVar new_attach_ivar = attach_ivar; bool updated = true; // recursively update the relations while (updated) { From 7a0aba34731ebcebbdab019f6db6b5af44a84540 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 24 Jun 2020 10:02:08 -0700 Subject: [PATCH 3/5] lint & typo --- src/te/schedule/schedule_dataflow_rewrite.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index 1dda4fafcf2c..d70d5f574554 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -684,7 +684,6 @@ void LegalizeInvalidAttach(ScheduleNode* sch) { } } - Schedule Schedule::normalize() { Schedule sn = copy(); InjectInline(sn.operator->()); From 518561748c752098e6f4f4bbcdb9d9e04ff3682a Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 24 Jun 2020 20:17:10 -0700 Subject: [PATCH 4/5] address comment --- src/te/schedule/schedule_dataflow_rewrite.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index d70d5f574554..173369ecca24 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -622,11 +622,14 @@ void LegalizeInvalidAttach(ScheduleNode* sch) { // we will move the compute_at location to the newly fused iterator. // Note that case 2 can only happen if the target of compute_at // is the innermost operand of fuse operation. - + + // Map an old invalid attach point to its new valid attach point std::unordered_map replace_map; for (Stage stage : sch->stages) { for (Stage s = stage; s.defined();) { + // The following logic is simiar to the `CreateAttachPath` in `src/te/schedule/graph.h`, + // because we follow the validation check in that function to legalize the attach. Stage spec = s.GetAttachSpec(); if (spec->attach_type != kScope) { break; @@ -641,6 +644,7 @@ void LegalizeInvalidAttach(ScheduleNode* sch) { IterVar iv = s->leaf_iter_vars[i - 1]; if (!start_attach && iv.same_as(attach_ivar)) { start_attach = true; + break; } } From c9ec14a64100bdcd228ffae0bac28cc0ae584727 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 24 Jun 2020 20:24:51 -0700 Subject: [PATCH 5/5] fix lint --- src/te/schedule/schedule_dataflow_rewrite.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index 173369ecca24..f130cb438113 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -622,7 +622,7 @@ void LegalizeInvalidAttach(ScheduleNode* sch) { // we will move the compute_at location to the newly fused iterator. // Note that case 2 can only happen if the target of compute_at // is the innermost operand of fuse operation. - + // Map an old invalid attach point to its new valid attach point std::unordered_map replace_map;