diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc
index bf58c8d5be41..15173c2c79db 100644
--- a/src/relay/backend/graph_plan_memory.cc
+++ b/src/relay/backend/graph_plan_memory.cc
@@ -82,9 +82,8 @@ class StorageAllocaBaseVisitor : public ExprVisitor {
   void VisitExpr_(const TupleNode* op) final {
     std::vector<StorageToken*> fields;
     for (Expr field : op->fields) {
-      auto tok = GetToken(field);
-      ICHECK_EQ(tok.size(), 1U);
-      fields.push_back(tok[0]);
+      auto tokens = GetToken(field);
+      fields.insert(fields.end(), tokens.begin(), tokens.end());
     }
     token_map_[op] = fields;
   }
diff --git a/tests/python/relay/test_backend_graph_runtime.py b/tests/python/relay/test_backend_graph_runtime.py
index 1bd551004ad7..3c42b7b4196f 100644
--- a/tests/python/relay/test_backend_graph_runtime.py
+++ b/tests/python/relay/test_backend_graph_runtime.py
@@ -184,6 +184,31 @@ def unit_numpy(X, W):
     tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
 
 
+def test_compile_nested_tuples():
+    x = relay.var("x", shape=(10,))
+    x1 = x + relay.const(1.0)
+    x2 = x1 + relay.const(1.0)
+    x3 = x2 + relay.const(1.0)
+    x4 = x3 + relay.const(1.0)
+    out = relay.Tuple([x1, relay.Tuple([relay.Tuple([x2, x3]), x4])])
+    func = relay.Function([x], out)
+
+    graph, lib, _ = relay.build(tvm.IRModule.from_expr(func), "llvm")
+    mod = graph_runtime.create(graph, lib, ctx=tvm.cpu(0))
+
+    x_data = np.random.uniform(size=(10,)).astype(np.float32)
+    mod.set_input(x=x_data)
+    mod.run()
+
+    assert mod.get_num_outputs() == 4
+
+    ref = x_data + 1
+    for i in range(mod.get_num_outputs()):
+        out = mod.get_output(i).asnumpy()
+        tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
+        ref = ref + 1
+
+
 if __name__ == "__main__":
     test_plan_memory()
     test_with_params()
@@ -191,3 +216,4 @@ def unit_numpy(X, W):
     test_add_op_tensor()
     test_add_op_broadcast()
     test_gru_like()
+    test_compile_nested_tuples()