diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index 3f0dfcf149c2..952513db4c3d 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -325,10 +325,12 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor { } // - Increase the reference counters of the arguments when the callee is - // a PrimFunc of the context module. + // a PrimFunc of the context module or an external function via 'call_packed'. + // It assumes external function calls via 'call_packed' do not retain memory + // from the arguments. // - Otherwise, discard the tokens used by the arguments, as there might be // potential external reference. - if (IsPrimFuncGlobalVar(call->op)) { + if (IsPrimFuncGlobalVar(call->op) || call->op->IsInstance<ExternFuncNode>()) { ICHECK(!block_stack_.empty()); for (const Expr& arg : call->args) { Tokens tokens = GetTokensWithAllocSiteCheck(arg, block_stack_.back()); diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py index 5198d9e07525..521fcc1924e7 100644 --- a/tests/python/relax/test_transform_static_plan_block_memory.py +++ b/tests/python/relax/test_transform_static_plan_block_memory.py @@ -586,6 +586,48 @@ def main(x: R.Tensor((2, 3), "float32")): tvm.ir.assert_structural_equal(mod, Module) +def test_call_packed_external_func(): + @I.ir_module + class Module: + @R.function + def main(x: R.Tensor((2, 3), "float32")): + alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _ = R.call_packed("extern_func", x, alloc, sinfo_args=[R.Tuple()]) + y: R.Tensor((2, 3), dtype="float32") = alloc + alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), dtype="float32", runtime_device_index=0 + ) + _1 = R.call_packed("extern_func", y, alloc1, sinfo_args=[R.Tuple()]) 
+ z: R.Tensor((2, 3), dtype="float32") = alloc1 + return z + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + storage: R.Object = R.memory.alloc_storage( + R.shape([24]), R.prim_value(0), R.str("global"), R.dtype("float32") + ) + alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor( + storage, R.prim_value(0), R.shape([2, 3]), R.dtype("float32") + ) + _: R.Tuple = R.call_packed("extern_func", x, alloc, sinfo_args=(R.Tuple(),)) + y: R.Tensor((2, 3), dtype="float32") = alloc + alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor( + R.shape([2, 3]), R.dtype("float32"), R.prim_value(0) + ) + _1: R.Tuple = R.call_packed("extern_func", y, alloc1, sinfo_args=(R.Tuple(),)) + _2: R.Tuple = R.memory.kill_tensor(alloc) + z: R.Tensor((2, 3), dtype="float32") = alloc1 + _3: R.Tuple = R.memory.kill_storage(storage) + return z + + mod = relax.transform.StaticPlanBlockMemory()(Module) + tvm.ir.assert_structural_equal(mod, Expected) + + def test_symbolic_shape(): @tvm.script.ir_module class Module: