From 086f1f940c658fde73190c4936d4a5702a4031d5 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 20 Mar 2024 07:21:54 -0500 Subject: [PATCH 1/3] [IR][Relax] Improve highlighting in assert_structural_equal Prior to this commit, `tvm.ir.assert_structural_equal` would highlight an entire `relax::BindingBlock` if the number of elements in the binding block differs. This can result in the entire Relax function being highlighted, making it difficult to identify the location of the mismatch. This commit makes the following changes, to improve the error messages that occur when `tvm.ir.assert_structural_equal` raises an exception. - In `"node.StructuralEqual"`, set `defer_fails = true` when `assert_mode` is true. This highlights the first mismatch of an `Array`, rather than the entire array, in cases where the LHS and RHS have different sizes. - In the `SHashReduce` for `VarBinding` and `MatchCast`, visit the value first, and then the variable to which it is bound. This highlights the mismatched expression, rather than mismatches in the resulting struct info. - In `SEqualHandlerDefault::Impl::SEqualReduce`, defer the failure if enabled. This highlights the first mismatch, which may also have been deferred, rather than an early return a later mismatch occurs involving `NullOpt`. --- include/tvm/relax/expr.h | 12 +++++----- src/node/structural_equal.cc | 40 +++++++++++++++++++++++++------- tests/python/relax/test_utils.py | 24 +++++++++++++++++++ 3 files changed, 61 insertions(+), 15 deletions(-) diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index fdbd7bd8eb2c..241ca16b2886 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -782,15 +782,15 @@ class MatchCastNode : public BindingNode { bool SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const { // NOTE: pattern can contain ShapeExpr which defines the vars - return equal.DefEqual(var, other->var) && equal.DefEqual(struct_info, other->struct_info) && - equal(value, other->value); + return equal(value, other->value) && equal.DefEqual(struct_info, other->struct_info) && + equal.DefEqual(var, other->var); } void SHashReduce(SHashReducer hash_reduce) const { // NOTE: pattern can contain ShapeExpr which defines the vars - hash_reduce.DefHash(var); - hash_reduce.DefHash(struct_info); hash_reduce(value); + hash_reduce.DefHash(struct_info); + hash_reduce.DefHash(var); } static constexpr const char* _type_key = "relax.expr.MatchCast"; @@ -823,11 +823,11 @@ class VarBindingNode : public BindingNode { } bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const { - return equal.DefEqual(var, other->var) && equal(value, other->value); + return equal(value, other->value) && equal.DefEqual(var, other->var); } void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce.DefHash(var); hash_reduce(value); + hash_reduce.DefHash(var); } static constexpr const char* _type_key = "relax.expr.VarBinding"; static constexpr const bool _type_has_method_sequal_reduce = true; diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 66a347f6b8ba..a940d8fdf52b 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -27,6 +27,7 @@ #include #include +#include #include #include "ndarray_hash_equal.h" @@ -249,15 +250,30 @@ class SEqualHandlerDefault::Impl { // in which case we can use same_as for quick checking, // or we have to run deep comparison and avoid to use same_as checks. auto run = [=]() { - if (!lhs.defined() && !rhs.defined()) return true; - if (!lhs.defined() && rhs.defined()) return false; - if (!rhs.defined() && lhs.defined()) return false; - if (lhs->type_index() != rhs->type_index()) return false; - auto it = equal_map_lhs_.find(lhs); - if (it != equal_map_lhs_.end()) { - return it->second.same_as(rhs); + std::optional early_result = [&]() -> std::optional { + if (!lhs.defined() && !rhs.defined()) return true; + if (!lhs.defined() && rhs.defined()) return false; + if (!rhs.defined() && lhs.defined()) return false; + if (lhs->type_index() != rhs->type_index()) return false; + auto it = equal_map_lhs_.find(lhs); + if (it != equal_map_lhs_.end()) { + return it->second.same_as(rhs); + } + if (equal_map_rhs_.count(rhs)) return false; + + return std::nullopt; + }(); + + if (early_result.has_value()) { + if (early_result.value()) { + return true; + } else if (IsPathTracingEnabled() && IsFailDeferralEnabled() && current_paths.defined()) { + DeferFail(current_paths.value()); + return true; + } else { + return false; + } } - if (equal_map_rhs_.count(rhs)) return false; // need to push to pending tasks in this case pending_tasks_.emplace_back(lhs, rhs, map_free_vars, current_paths); @@ -530,8 +546,14 @@ bool SEqualHandlerDefault::DispatchSEqualReduce(const ObjectRef& lhs, const Obje TVM_REGISTER_GLOBAL("node.StructuralEqual") .set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool assert_mode, bool map_free_vars) { + // If we are asserting on failure, then the `defer_fails` option + // should be enabled, to provide better error messages. For + // example, if the number of bindings in a `relax::BindingBlock` + // differs, highlighting the first difference rather than the + // entire block. + bool defer_fails = assert_mode; Optional first_mismatch; - return SEqualHandlerDefault(assert_mode, &first_mismatch, false) + return SEqualHandlerDefault(assert_mode, &first_mismatch, defer_fails) .Equal(lhs, rhs, map_free_vars); }); diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py index 0cae5101a755..4754e8db7fda 100644 --- a/tests/python/relax/test_utils.py +++ b/tests/python/relax/test_utils.py @@ -14,6 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +import re + import pytest import tvm @@ -122,6 +125,27 @@ def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"): assert_structural_equal(Actual, Expected) +def test_assert_structural_equal_in_seqexpr(): + """The first mismatch is correctly identified.""" + + @R.function(private=True) + def func_1(A: R.Tensor([16, 16], "float32")): + B = R.concat([A, A]) + return B + + @R.function(private=True) + def func_2(A: R.Tensor([16, 16], "float32")): + B = R.add(A, A) + C = R.add(B, B) + return B + + with pytest.raises( + ValueError, + match=re.escape(".body.blocks[0].bindings[0].value.op"), + ): + assert_structural_equal(func_1, func_2) + + def test_structural_equal_of_call_nodes(): """relax.Call must be compared by structural equality, not reference""" From b3437d52aa7c6c127be417d214f64367d4933a72 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 20 Mar 2024 20:38:36 -0500 Subject: [PATCH 2/3] DeferFail should follow assert_mode --- src/node/structural_equal.cc | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index a940d8fdf52b..e0de514122b8 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -404,10 +404,7 @@ class SEqualHandlerDefault::Impl { auto& entry = task_stack_.back(); if (entry.force_fail) { - if (IsPathTracingEnabled() && !first_mismatch_->defined()) { - *first_mismatch_ = entry.current_paths; - } - return false; + return CheckResult(false, entry.lhs, entry.rhs, entry.current_paths); } if (entry.children_expanded) { From 0050cbccd0f1283c3a9df8809de0c50ea78b1cb2 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 21 Mar 2024 12:18:26 -0500 Subject: [PATCH 3/3] Handle recursively defined lambda functions --- include/tvm/relax/expr.h | 24 ++++----------- src/relax/ir/expr.cc | 50 ++++++++++++++++++++++++++++++++ tests/python/relax/test_utils.py | 39 ++++++++++++++++++++++++- 3 files changed, 93 insertions(+), 20 deletions(-) diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 241ca16b2886..71f538b3110e 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -780,18 +780,8 @@ class MatchCastNode : public BindingNode { v->Visit("span", &span); } - bool SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const { - // NOTE: pattern can contain ShapeExpr which defines the vars - return equal(value, other->value) && equal.DefEqual(struct_info, other->struct_info) && - equal.DefEqual(var, other->var); - } - - void SHashReduce(SHashReducer hash_reduce) const { - // NOTE: pattern can contain ShapeExpr which defines the vars - hash_reduce(value); - hash_reduce.DefHash(struct_info); - hash_reduce.DefHash(var); - } + bool SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const; + void SHashReduce(SHashReducer hash_reduce) const; static constexpr const char* _type_key = "relax.expr.MatchCast"; static constexpr const bool _type_has_method_sequal_reduce = true; @@ -822,13 +812,9 @@ class VarBindingNode : public BindingNode { v->Visit("span", &span); } - bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const { - return equal(value, other->value) && equal.DefEqual(var, other->var); - } - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(value); - hash_reduce.DefHash(var); - } + bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const; + void SHashReduce(SHashReducer hash_reduce) const; + static constexpr const char* _type_key = "relax.expr.VarBinding"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 1bc7267af6ca..b709039e8c32 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -384,6 +384,33 @@ TVM_REGISTER_GLOBAL("relax.MatchCast") return MatchCast(var, value, struct_info, span); }); +bool MatchCastNode::SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const { + if (value->IsInstance()) { + // Recursive function definitions may reference the bound variable + // within the value being bound. In these cases, the + // `DefEqual(var, other->var)` must occur first, to ensure it is + // defined at point of use. + return equal.DefEqual(var, other->var) && equal.DefEqual(struct_info, other->struct_info) && + equal(value, other->value); + } else { + // In all other cases, visit the bound value before the variable + // it is bound to, in order to provide better error messages. + return equal(value, other->value) && equal.DefEqual(struct_info, other->struct_info) && + equal.DefEqual(var, other->var); + } +} +void MatchCastNode::SHashReduce(SHashReducer hash_reduce) const { + if (value->IsInstance()) { + hash_reduce.DefHash(var); + hash_reduce.DefHash(struct_info); + hash_reduce(value); + } else { + hash_reduce(value); + hash_reduce.DefHash(struct_info); + hash_reduce.DefHash(var); + } +} + TVM_REGISTER_NODE_TYPE(VarBindingNode); VarBinding::VarBinding(Var var, Expr value, Span span) { @@ -398,6 +425,29 @@ TVM_REGISTER_GLOBAL("relax.VarBinding").set_body_typed([](Var var, Expr value, S return VarBinding(var, value, span); }); +bool VarBindingNode::SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const { + if (value->IsInstance()) { + // Recursive function definitions may reference the bound variable + // within the value being bound. In these cases, the + // `DefEqual(var, other->var)` must occur first, to ensure it is + // defined at point of use. + return equal.DefEqual(var, other->var) && equal(value, other->value); + } else { + // In all other cases, visit the bound value before the variable + // it is bound to, in order to provide better error messages. + return equal(value, other->value) && equal.DefEqual(var, other->var); + } +} +void VarBindingNode::SHashReduce(SHashReducer hash_reduce) const { + if (value->IsInstance()) { + hash_reduce.DefHash(var); + hash_reduce(value); + } else { + hash_reduce(value); + hash_reduce.DefHash(var); + } +} + TVM_REGISTER_NODE_TYPE(BindingBlockNode); BindingBlock::BindingBlock(Array bindings, Span span) { diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py index 4754e8db7fda..9abc53484b7f 100644 --- a/tests/python/relax/test_utils.py +++ b/tests/python/relax/test_utils.py @@ -22,7 +22,7 @@ import tvm from tvm import relax from tvm.ir.base import assert_structural_equal -from tvm.script.parser import relax as R +from tvm.script.parser import relax as R, tir as T def test_copy_with_new_vars(): @@ -169,5 +169,42 @@ def uses_two_different_objects(): tvm.ir.assert_structural_equal(uses_same_object_twice, uses_two_different_objects) +def test_structural_equal_with_recursive_lambda_function(): + """A recursive lambda function may be checked for structural equality + + Recursive function definitions may reference the bound variable + within the value being bound. In these cases, the `DefEqual(var, + other->var)` must occur first, to ensure it is defined at point of + use. + + In all other cases, checking for structural equality of the bound + value prior to the variable provides a better error message. + """ + + def define_function(): + @R.function + def func(n: R.Prim("int64")): + @R.function + def recursive_lambda(i_arg: R.Prim(value="i")) -> R.Prim("int64"): + i = T.int64() + if R.prim_value(i == 0): + output = R.prim_value(T.int64(0)) + else: + remainder_relax = recursive_lambda(R.prim_value(i - 1)) + remainder_tir = T.int64() + _ = R.match_cast(remainder_relax, R.Prim(value=remainder_tir)) + output = R.prim_value(i + remainder_tir) + return output + + return recursive_lambda(n) + + return func + + func_1 = define_function() + func_2 = define_function() + + tvm.ir.assert_structural_equal(func_1, func_2) + + if __name__ == "__main__": pytest.main([__file__])