diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index c1f6cdc63974..c0f1db97b538 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -366,7 +366,7 @@ class TypeInferencer::Resolver : public ExprMutator { } Expr VisitExpr_(const GlobalVarNode* op) final { - return AttachCheckedType(op); + return GetRef(op); } Expr VisitExpr_(const OpNode* op) final { diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 8f92fc0f5192..b1823004022c 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -123,6 +123,16 @@ def f(x) { assert relay.ir_pass.infer_type(f).checked_type == relay.FuncType([a], a) assert relay.ir_pass.infer_type(fx).checked_type == a +def test_global_var_cow_issue(): + env = relay.env.Environment({}) + gv = relay.GlobalVar("foo") + x = relay.var('x', shape=[]) + func = relay.Function([x], relay.Call(gv, [x]), relay.TensorType([], 'float32')) + env[gv] = func + # They should both point to the same global variable if global variables are + # stable across type checking. + assert gv == func.body.op + if __name__ == "__main__": test_free_expr() test_dual_op() @@ -134,3 +144,4 @@ def f(x) { test_free_expr() test_type_args() test_self_reference() + test_global_var_cow_issue() \ No newline at end of file