diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index ad9ba1b2069d..f1398786b93b 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -291,7 +291,9 @@ class AOTExecutorCodegen : public MixedModeVisitor { args.push_back(param_handle); } else { auto var_arg = FindExpr(arg); - args.push_back(var_arg[0]); + for (const auto& var : var_arg) { + args.push_back(var); + } } } diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index e117302d0ed8..73aa385161f6 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -26,6 +26,7 @@ from tvm.ir.module import IRModule from tvm.relay import testing, transform from tvm.relay.testing import byoc +from tvm.relay.op.annotation import compiler_begin, compiler_end from aot_test_utils import ( AOTTestModel, AOT_DEFAULT_RUNNER, @@ -312,8 +313,58 @@ def test_mobilenet(debug_calculated_workspaces, workspace_byte_alignment): ) -def test_byoc_microtvm(): - """This is a simple test case to check BYOC capabilities of AOT""" +@pytest.mark.parametrize("merge_compiler_regions", [False, True]) +def test_byoc_microtvm(merge_compiler_regions): + """This is a simple test to check BYOC capabilities of AOT - with and without merging compiler regions to test for https://github.com/apache/tvm/issues/9036""" + use_unpacked_api = False + interface_api = "packed" + test_runner = AOT_DEFAULT_RUNNER + + x = relay.var("x", shape=(10, 10)) + w0 = relay.var("w0", shape=(10, 10)) + w1 = relay.var("w1", shape=(10, 10)) + + # z0 = x + w0 + x_ = compiler_begin(x, "ccompiler") + w0_ = compiler_begin(w0, "ccompiler") + z0_ = relay.add(x_, w0_) + z0 = compiler_end(z0_, "ccompiler") + + # z1 = z0 + w1 + z0__ = compiler_begin(z0, "ccompiler") + w1_ = compiler_begin(w1, "ccompiler") + z1_ = relay.add(z0__, w1_) + z1 = compiler_end(z1_, "ccompiler") + + # z2 = z0 + z1 + z2 = relay.add(z0, z1) + + f = relay.Function([x, w0, w1], z2) + mod = tvm.IRModule() + mod["main"] = f + + if merge_compiler_regions: + mod = transform.MergeCompilerRegions()(mod) + + mod = transform.PartitionGraph("mod_name")(mod) + mod = transform.InferType()(mod) + + x_data = [("x", np.random.rand(10, 10).astype("float32"))] + w_data = [("w{}".format(i), np.random.rand(10, 10).astype("float32")) for i in range(2)] + + map_inputs = OrderedDict(x_data + w_data) + output_list = generate_ref_data(mod, map_inputs) + compile_and_run( + AOTTestModel(name="my_mod", module=mod, inputs=map_inputs, outputs=output_list), + test_runner, + interface_api, + use_unpacked_api, + ) + + +@pytest.mark.parametrize("merge_compiler_regions", [False, True]) +def test_byoc_microtvm_multiple_subgraphs(merge_compiler_regions): + """This is a test case to check BYOC capabilities of AOT with multiple sub graphs""" use_unpacked_api = False interface_api = "packed" test_runner = AOT_DEFAULT_RUNNER @@ -347,6 +398,9 @@ def test_byoc_microtvm(): ann = byoc.CcompilerAnnotator() mod["main"] = ann.visit(f) + if merge_compiler_regions: + mod = transform.MergeCompilerRegions()(mod) + mod = tvm.relay.transform.PartitionGraph("mod_name")(mod) mod = tvm.relay.transform.InferType()(mod)