diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index fdbd7bd8eb2c..71f538b3110e 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -780,18 +780,8 @@ class MatchCastNode : public BindingNode { v->Visit("span", &span); } - bool SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const { - // NOTE: pattern can contain ShapeExpr which defines the vars - return equal.DefEqual(var, other->var) && equal.DefEqual(struct_info, other->struct_info) && - equal(value, other->value); - } - - void SHashReduce(SHashReducer hash_reduce) const { - // NOTE: pattern can contain ShapeExpr which defines the vars - hash_reduce.DefHash(var); - hash_reduce.DefHash(struct_info); - hash_reduce(value); - } + bool SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const; + void SHashReduce(SHashReducer hash_reduce) const; static constexpr const char* _type_key = "relax.expr.MatchCast"; static constexpr const bool _type_has_method_sequal_reduce = true; @@ -822,13 +812,9 @@ class VarBindingNode : public BindingNode { v->Visit("span", &span); } - bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const { - return equal.DefEqual(var, other->var) && equal(value, other->value); - } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce.DefHash(var); - hash_reduce(value); - } + bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const; + void SHashReduce(SHashReducer hash_reduce) const; + static constexpr const char* _type_key = "relax.expr.VarBinding"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 66a347f6b8ba..e0de514122b8 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -27,6 +27,7 @@ #include #include +#include #include #include "ndarray_hash_equal.h" @@ -249,15 +250,30 @@ class SEqualHandlerDefault::Impl { // in which case we can use same_as for quick checking, // or we have to run deep comparison and avoid to use same_as checks. auto run = [=]() { - if (!lhs.defined() && !rhs.defined()) return true; - if (!lhs.defined() && rhs.defined()) return false; - if (!rhs.defined() && lhs.defined()) return false; - if (lhs->type_index() != rhs->type_index()) return false; - auto it = equal_map_lhs_.find(lhs); - if (it != equal_map_lhs_.end()) { - return it->second.same_as(rhs); + std::optional early_result = [&]() -> std::optional { + if (!lhs.defined() && !rhs.defined()) return true; + if (!lhs.defined() && rhs.defined()) return false; + if (!rhs.defined() && lhs.defined()) return false; + if (lhs->type_index() != rhs->type_index()) return false; + auto it = equal_map_lhs_.find(lhs); + if (it != equal_map_lhs_.end()) { + return it->second.same_as(rhs); + } + if (equal_map_rhs_.count(rhs)) return false; + + return std::nullopt; + }(); + + if (early_result.has_value()) { + if (early_result.value()) { + return true; + } else if (IsPathTracingEnabled() && IsFailDeferralEnabled() && current_paths.defined()) { + DeferFail(current_paths.value()); + return true; + } else { + return false; + } } - if (equal_map_rhs_.count(rhs)) return false; // need to push to pending tasks in this case pending_tasks_.emplace_back(lhs, rhs, map_free_vars, current_paths); @@ -388,10 +404,7 @@ class SEqualHandlerDefault::Impl { auto& entry = task_stack_.back(); if (entry.force_fail) { - if (IsPathTracingEnabled() && !first_mismatch_->defined()) { - *first_mismatch_ = entry.current_paths; - } - return false; + return CheckResult(false, entry.lhs, entry.rhs, entry.current_paths); } if (entry.children_expanded) { @@ -530,8 +543,14 @@ bool SEqualHandlerDefault::DispatchSEqualReduce(const ObjectRef& lhs, const Obje TVM_REGISTER_GLOBAL("node.StructuralEqual") .set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool assert_mode, bool map_free_vars) { + // If we are asserting on failure, then the `defer_fails` option + // should be enabled, to provide better error messages. For + // example, if the number of bindings in a `relax::BindingBlock` + // differs, highlighting the first difference rather than the + // entire block. + bool defer_fails = assert_mode; Optional first_mismatch; - return SEqualHandlerDefault(assert_mode, &first_mismatch, false) + return SEqualHandlerDefault(assert_mode, &first_mismatch, defer_fails) .Equal(lhs, rhs, map_free_vars); }); diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 1bc7267af6ca..b709039e8c32 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -384,6 +384,33 @@ TVM_REGISTER_GLOBAL("relax.MatchCast") return MatchCast(var, value, struct_info, span); }); +bool MatchCastNode::SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const { + if (value->IsInstance()) { + // Recursive function definitions may reference the bound variable + // within the value being bound. In these cases, the + // `DefEqual(var, other->var)` must occur first, to ensure it is + // defined at point of use. + return equal.DefEqual(var, other->var) && equal.DefEqual(struct_info, other->struct_info) && + equal(value, other->value); + } else { + // In all other cases, visit the bound value before the variable + // it is bound to, in order to provide better error messages. + return equal(value, other->value) && equal.DefEqual(struct_info, other->struct_info) && + equal.DefEqual(var, other->var); + } +} +void MatchCastNode::SHashReduce(SHashReducer hash_reduce) const { + if (value->IsInstance()) { + hash_reduce.DefHash(var); + hash_reduce.DefHash(struct_info); + hash_reduce(value); + } else { + hash_reduce(value); + hash_reduce.DefHash(struct_info); + hash_reduce.DefHash(var); + } +} + TVM_REGISTER_NODE_TYPE(VarBindingNode); VarBinding::VarBinding(Var var, Expr value, Span span) { @@ -398,6 +425,29 @@ TVM_REGISTER_GLOBAL("relax.VarBinding").set_body_typed([](Var var, Expr value, S return VarBinding(var, value, span); }); +bool VarBindingNode::SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const { + if (value->IsInstance()) { + // Recursive function definitions may reference the bound variable + // within the value being bound. In these cases, the + // `DefEqual(var, other->var)` must occur first, to ensure it is + // defined at point of use. + return equal.DefEqual(var, other->var) && equal(value, other->value); + } else { + // In all other cases, visit the bound value before the variable + // it is bound to, in order to provide better error messages. + return equal(value, other->value) && equal.DefEqual(var, other->var); + } +} +void VarBindingNode::SHashReduce(SHashReducer hash_reduce) const { + if (value->IsInstance()) { + hash_reduce.DefHash(var); + hash_reduce(value); + } else { + hash_reduce(value); + hash_reduce.DefHash(var); + } +} + TVM_REGISTER_NODE_TYPE(BindingBlockNode); BindingBlock::BindingBlock(Array bindings, Span span) { diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py index 0cae5101a755..9abc53484b7f 100644 --- a/tests/python/relax/test_utils.py +++ b/tests/python/relax/test_utils.py @@ -14,12 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +import re + import pytest import tvm from tvm import relax from tvm.ir.base import assert_structural_equal -from tvm.script.parser import relax as R +from tvm.script.parser import relax as R, tir as T def test_copy_with_new_vars(): @@ -122,6 +125,27 @@ def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"): assert_structural_equal(Actual, Expected) +def test_assert_structural_equal_in_seqexpr(): + """The first mismatch is correctly identified.""" + + @R.function(private=True) + def func_1(A: R.Tensor([16, 16], "float32")): + B = R.concat([A, A]) + return B + + @R.function(private=True) + def func_2(A: R.Tensor([16, 16], "float32")): + B = R.add(A, A) + C = R.add(B, B) + return B + + with pytest.raises( + ValueError, + match=re.escape(".body.blocks[0].bindings[0].value.op"), + ): + assert_structural_equal(func_1, func_2) + + def test_structural_equal_of_call_nodes(): """relax.Call must be compared by structural equality, not reference""" @@ -145,5 +169,42 @@ def uses_two_different_objects(): tvm.ir.assert_structural_equal(uses_same_object_twice, uses_two_different_objects) +def test_structural_equal_with_recursive_lambda_function(): + """A recursive lambda function may be checked for structural equality + + Recursive function definitions may reference the bound variable + within the value being bound. In these cases, the `DefEqual(var, + other->var)` must occur first, to ensure it is defined at point of + use. + + In all other cases, checking for structural equality of the bound + value prior to the variable provides a better error message. + """ + + def define_function(): + @R.function + def func(n: R.Prim("int64")): + @R.function + def recursive_lambda(i_arg: R.Prim(value="i")) -> R.Prim("int64"): + i = T.int64() + if R.prim_value(i == 0): + output = R.prim_value(T.int64(0)) + else: + remainder_relax = recursive_lambda(R.prim_value(i - 1)) + remainder_tir = T.int64() + _ = R.match_cast(remainder_relax, R.Prim(value=remainder_tir)) + output = R.prim_value(i + remainder_tir) + return output + + return recursive_lambda(n) + + return func + + func_1 = define_function() + func_2 = define_function() + + tvm.ir.assert_structural_equal(func_1, func_2) + + if __name__ == "__main__": pytest.main([__file__])