From 4e25ddfd35531b372c0b550d97f62b5844e73271 Mon Sep 17 00:00:00 2001 From: Thrsu <89128704+Thrsu@users.noreply.github.com> Date: Fri, 1 Nov 2024 22:21:50 +0800 Subject: [PATCH 1/3] Update static_plan_block_memory.cc --- src/relax/transform/static_plan_block_memory.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index 74200526b699..f8026dadc6ab 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -730,6 +730,15 @@ class StorageAllocator : public StorageAllocatorBaseVisitor { } } + void VisitBindingBlock_(const DataflowBlockNode* block) override { + // We maintain a block stack for token allocation-site and use-site check. + block_stack_.push_back(block); + ExprVisitor::VisitBindingBlock_(block); + ICHECK(!block_stack_.empty()); + ICHECK(block_stack_.back() == block); + block_stack_.pop_back(); + } + void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final { static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); if (call->op == alloc_tensor_op) { From 0e66ba8b5d2c6d5f5f09d773a00e9e9e374aef88 Mon Sep 17 00:00:00 2001 From: Thrsu <89128704+Thrsu@users.noreply.github.com> Date: Sat, 9 Nov 2024 23:16:43 +0800 Subject: [PATCH 2/3] Add test case. --- ...test_transform_static_plan_block_memory.py | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py index 1150827b19f9..28015f0eecff 100644 --- a/tests/python/relax/test_transform_static_plan_block_memory.py +++ b/tests/python/relax/test_transform_static_plan_block_memory.py @@ -1504,5 +1504,46 @@ def main() -> R.Tensor((128,), dtype="float32"): tvm.ir.assert_structural_equal(after, Expected) +def test_with_dataflow(): + @I.ir_module + class Before: + @T.prim_func + def exp(A: T.handle, B: T.handle): + T.evaluate(0) + + @R.function + def main(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + cls = Before + with R.dataflow(): + alloc: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor( + R.shape([10]), R.dtype("float32"), runtime_device_index=0 + ) + _: R.Tuple() = cls.exp(x, alloc) + gv: R.Tensor((10,), dtype="float32") = alloc + R.output(gv) + return gv + + @I.ir_module + class Expected: + @T.prim_func + def exp(A: T.handle, B: T.handle): + T.evaluate(0) + + @R.function + def main(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + cls = Expected + with R.dataflow(): + alloc: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor( + R.shape([10]), R.dtype("float32"), R.prim_value(0), R.str("global") + ) + cls.exp(x, alloc) + gv: R.Tensor((10,), dtype="float32") = alloc + R.output(gv) + return gv + + after = relax.transform.StaticPlanBlockMemory()(Before) + tvm.ir.assert_structural_equal(after, Expected) + + if __name__ == "__main__": tvm.testing.main() From 260606c5e4c9157275634a33a57b9d3a15f44433 Mon Sep 17 00:00:00 2001 From: Thrsu <89128704+Thrsu@users.noreply.github.com> Date: Sun, 10 Nov 2024 16:51:33 +0800 Subject: [PATCH 3/3] Update static_plan_block_memory.cc --- .../transform/static_plan_block_memory.cc | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index f8026dadc6ab..44e338cbe8ca 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -314,6 +314,15 @@ class StorageAllocatorBaseVisitor : public ExprVisitor { SetTokens(binding->var.get(), token_map_[binding->value.get()]); } + void VisitBindingBlock_(const DataflowBlockNode* block) override { + // We maintain a block stack for token allocation-site and use-site check. + block_stack_.push_back(block); + ExprVisitor::VisitBindingBlock_(block); + ICHECK(!block_stack_.empty()); + ICHECK(block_stack_.back() == block); + block_stack_.pop_back(); + } + void VisitExpr_(const TupleNode* tuple) final { Array tokens; tokens.reserve(tuple->fields.size()); @@ -730,15 +739,6 @@ class StorageAllocator : public StorageAllocatorBaseVisitor { } } - void VisitBindingBlock_(const DataflowBlockNode* block) override { - // We maintain a block stack for token allocation-site and use-site check. - block_stack_.push_back(block); - ExprVisitor::VisitBindingBlock_(block); - ICHECK(!block_stack_.empty()); - ICHECK(block_stack_.back() == block); - block_stack_.pop_back(); - } - void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final { static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); if (call->op == alloc_tensor_op) {