diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc
index f2486356cce1..b48fbe44bd11 100644
--- a/src/relay/transforms/partition_graph.cc
+++ b/src/relay/transforms/partition_graph.cc
@@ -509,7 +509,7 @@ class NameMangleExtFuncs : public MixedModeMutator {
 
     // Walk the tree and mangle the functions. Then replace compiler functions
    // with mangled functions in the module
-    IRModule new_module;
+    IRModule new_module = IRModule({}, module_->type_definitions, module_->Imports());
     for (const auto& pair : glob_funcs) {
       if (auto* fn = pair.second.as<FunctionNode>()) {
         auto func = GetRef<Function>(fn);
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 55b150d948c1..29d420def184 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -1439,6 +1439,34 @@ def Optimize(mod):
     tvm.testing.assert_allclose(t0.body.data.numpy(), expected, rtol=1e-5, atol=1e-5)
 
 
+def test_preserve_type_import():
+    """Test to make sure type definition and imports are preserved during the BYOC pipeline."""
+    from tvm.relay.prelude import Prelude, StaticTensorArrayOps
+
+    def run(dtype, shape):
+        mod = tvm.IRModule()
+        p = Prelude(mod)
+        static_tensor_array_ops = StaticTensorArrayOps(p, dtype, shape)
+        static_tensor_array_ops.register()
+
+        tensor_array = p.get_global_var_static("tensor_array", dtype, shape)
+        tensor = p.get_tensor_ctor_static("tensor_constructor", dtype, shape)
+        write = p.get_global_var_static("tensor_array_write", dtype, shape)
+        gather = p.get_global_var_static("tensor_array_gather", dtype, shape)
+        v = relay.var("v")
+        indice = relay.var("indice")
+        init_tensor_array = tensor_array(relay.const(3))
+        tensor_array1 = write(init_tensor_array, relay.const(0), tensor(v))
+        tensor_array2 = write(tensor_array1, relay.const(1), tensor(v))
+        tensor_array3 = write(tensor_array2, relay.const(2), tensor(v))
+        out = gather(tensor_array3, indice)
+        mod["main"] = relay.Function([v, indice], out)
+        mod = transform.RemoveUnusedFunctions()(mod)
+        mod = transform.PartitionGraph()(mod)
+
+    run("float32", [2, 3])
+
+
 if __name__ == "__main__":
     test_multi_node_compiler()
     test_extern_ccompiler_single_op()
@@ -1460,3 +1488,4 @@ def Optimize(mod):
     test_flatten_tuple_output()
     test_tuple_output_exec()
     test_extern_opt()
+    test_preserve_type_import()
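
Reviewer note, not part of the diff: a minimal standalone sketch of what the one-line change in partition_graph.cc preserves. The identity main function and the subset check below are illustrative assumptions rather than code from this PR; they use only public APIs (Prelude, IRModule.get_global_type_vars, transform.PartitionGraph).

# Illustrative sketch (assumption: not part of this PR). Before the fix,
# NameMangleExtFuncs rebuilt the module as a bare `IRModule new_module;`,
# dropping the type definitions and imports that the prelude had registered;
# passing `module_->type_definitions` and `module_->Imports()` keeps them.
import tvm
from tvm import relay
from tvm.relay import transform
from tvm.relay.prelude import Prelude

mod = tvm.IRModule()
Prelude(mod)  # registers ADT type definitions (List, Option, ...) in `mod`

# A trivial main so the module type-checks; hypothetical, for illustration only.
x = relay.var("x", shape=(2, 3), dtype="float32")
mod["main"] = relay.Function([x], x)

before = {tv.name_hint for tv in mod.get_global_type_vars()}
mod = transform.PartitionGraph()(mod)
after = {tv.name_hint for tv in mod.get_global_type_vars()}

# With the fix applied, partitioning must not lose any type definitions.
assert before.issubset(after), "type definitions were dropped by PartitionGraph"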