diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc index 0ed0e3df3056..e5626d338f6c 100644 --- a/src/relay/pass/alpha_eq.cc +++ b/src/relay/pass/alpha_eq.cc @@ -69,7 +69,7 @@ struct TypeAlphaEq : TypeVisitor { void VisitType_(const IncompleteTypeNode* bt1, const Type& t2) final { if (const IncompleteTypeNode* bt2 = t2.as()) { - equal = equal && bt1 == bt2; + equal = equal && bt1->kind == bt2->kind; return; } else { equal = false; diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index dd722399dac4..f64afb51d834 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -20,11 +20,8 @@ def test_incomplete_type_alpha_equal(): t2 = relay.IncompleteType(relay.Kind.Type) t3 = relay.IncompleteType(relay.Kind.Type) - # only equal when there is pointer equality - assert t2 == t2 - assert t1 == t1 assert t1 != t2 - assert t2 != t3 + assert t2 == t3 def test_type_param_alpha_equal():