Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/relax/transform/static_plan_block_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -325,10 +325,12 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
}

// - Increase the reference counters of the arguments when the callee is
// a PrimFunc of the context module.
// a PrimFunc of the context module or an external function via 'call_packed'.
// It assumes external function calls via 'call_packed' do not retain memory
// from the arguments.
// - Otherwise, discard the tokens used by the arguments, as there might be
// potential external reference.
if (IsPrimFuncGlobalVar(call->op)) {
if (IsPrimFuncGlobalVar(call->op) || call->op->IsInstance<ExternFuncNode>()) {
ICHECK(!block_stack_.empty());
for (const Expr& arg : call->args) {
Tokens tokens = GetTokensWithAllocSiteCheck(arg, block_stack_.back());
Expand Down
42 changes: 42 additions & 0 deletions tests/python/relax/test_transform_static_plan_block_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,48 @@ def main(x: R.Tensor((2, 3), "float32")):
tvm.ir.assert_structural_equal(mod, Module)


def test_call_packed_external_func():
    """Check static memory planning around ``R.call_packed`` external calls.

    The input module allocates two tensors with ``R.builtin.alloc_tensor`` and
    passes each of them to the packed function ``"extern_func"``.  After
    ``StaticPlanBlockMemory`` the first allocation is expected to be rewritten
    into an explicit ``R.memory.alloc_storage`` / ``R.memory.alloc_tensor``
    pair with matching ``kill_tensor`` / ``kill_storage`` calls, while the
    second allocation, whose value is returned from ``main``, is expected to
    remain an ``R.builtin.alloc_tensor``.
    """

    # Input IR: two builtin allocations, each handed to an external function
    # through ``call_packed``.
    @I.ir_module
    class Before:
        @R.function
        def main(x: R.Tensor((2, 3), "float32")):
            alloc: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
                R.shape([2, 3]), dtype="float32", runtime_device_index=0
            )
            _ = R.call_packed("extern_func", x, alloc, sinfo_args=[R.Tuple()])
            y: R.Tensor((2, 3), dtype="float32") = alloc
            alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
                R.shape([2, 3]), dtype="float32", runtime_device_index=0
            )
            _1 = R.call_packed("extern_func", y, alloc1, sinfo_args=[R.Tuple()])
            z: R.Tensor((2, 3), dtype="float32") = alloc1
            return z

    # Reference IR the pass output is compared against.
    @I.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
            # The first tensor is backed by an explicitly allocated storage.
            storage: R.Object = R.memory.alloc_storage(
                R.shape([24]), R.prim_value(0), R.str("global"), R.dtype("float32")
            )
            alloc: R.Tensor((2, 3), dtype="float32") = R.memory.alloc_tensor(
                storage, R.prim_value(0), R.shape([2, 3]), R.dtype("float32")
            )
            _: R.Tuple = R.call_packed("extern_func", x, alloc, sinfo_args=(R.Tuple(),))
            y: R.Tensor((2, 3), dtype="float32") = alloc
            # The tensor returned from the function keeps its builtin allocation.
            alloc1: R.Tensor((2, 3), dtype="float32") = R.builtin.alloc_tensor(
                R.shape([2, 3]), R.dtype("float32"), R.prim_value(0)
            )
            _1: R.Tuple = R.call_packed("extern_func", y, alloc1, sinfo_args=(R.Tuple(),))
            # ``alloc`` is released only after the last external call that reads it.
            _2: R.Tuple = R.memory.kill_tensor(alloc)
            z: R.Tensor((2, 3), dtype="float32") = alloc1
            _3: R.Tuple = R.memory.kill_storage(storage)
            return z

    planned = relax.transform.StaticPlanBlockMemory()(Before)
    tvm.ir.assert_structural_equal(planned, Expected)


def test_symbolic_shape():
@tvm.script.ir_module
class Module:
Expand Down