From 61195e159c8500c58a3e01e12b57687c1ae03835 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Mon, 6 Feb 2017 03:18:11 +0800 Subject: [PATCH] fix Stage.fuse --- src/api/api_lang.cc | 2 +- src/schedule/schedule_lang.cc | 3 ++- tests/python/unittest/test_lang_schedule.py | 15 +++++++++++++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc index 3393228f8104..96c61a76227b 100644 --- a/src/api/api_lang.cc +++ b/src/api/api_lang.cc @@ -216,7 +216,7 @@ TVM_REGISTER_API(_StageFuse) .set_body([](TVMArgs args, TVMRetValue* ret) { IterVar fused; args[0].operator Stage() - .split(args[1], args[2], &fused); + .fuse(args[1], args[2], &fused); *ret = fused; }); diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc index 58368ceb93b4..c84644104ca0 100644 --- a/src/schedule/schedule_lang.cc +++ b/src/schedule/schedule_lang.cc @@ -117,6 +117,7 @@ Stage& Stage::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT(*) IterVar fused(Range(), outer->var->name_hint + "." + inner->var->name_hint + ".fused"); + *p_target = fused; StageNode* self = operator->(); ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite(); ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite(); @@ -129,7 +130,7 @@ Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) { // NOLINT CHECK_EQ(pos_inner, pos_outer + 1) << "Can only fuse iterations that are consecutive between each other"; leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer, - leaf_vars->data.begin() + pos_inner); + leaf_vars->data.begin() + pos_inner + 1); leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer, fused.node_); return *this; diff --git a/tests/python/unittest/test_lang_schedule.py b/tests/python/unittest/test_lang_schedule.py index fcb573dab4c3..04f00751dde0 100644 --- a/tests/python/unittest/test_lang_schedule.py +++ b/tests/python/unittest/test_lang_schedule.py @@ -63,8 +63,23 @@ def test_tile(): xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5) assert tuple(s[T].leaf_iter_vars) == (xo, yo, xi, yi) + +def test_fuse(): + m = tvm.Var('m') + n = tvm.Var('n') + A = tvm.placeholder((m, n), name='A') + T = tvm.compute((m, n), lambda i, j: A[i, j]) + + s = tvm.Schedule(T.op) + xo, yo, xi, yi = s[T].tile(T.op.axis[0], T.op.axis[1], x_factor=10, y_factor=5) + fused = s[T].fuse(yo, xo) + assert any(isinstance(x, tvm.schedule.Fuse) for x in s[T].relations) + assert tuple(s[T].leaf_iter_vars) == (fused, xi, yi) + + if __name__ == "__main__": test_schedule_create() test_reorder() test_tile() test_split() + test_fuse()