Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 5 additions & 19 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -780,18 +780,8 @@ class MatchCastNode : public BindingNode {
v->Visit("span", &span);
}

bool SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const {
// NOTE: pattern can contain ShapeExpr which defines the vars
return equal.DefEqual(var, other->var) && equal.DefEqual(struct_info, other->struct_info) &&
equal(value, other->value);
}

void SHashReduce(SHashReducer hash_reduce) const {
// NOTE: pattern can contain ShapeExpr which defines the vars
hash_reduce.DefHash(var);
hash_reduce.DefHash(struct_info);
hash_reduce(value);
}
bool SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const;
void SHashReduce(SHashReducer hash_reduce) const;

static constexpr const char* _type_key = "relax.expr.MatchCast";
static constexpr const bool _type_has_method_sequal_reduce = true;
Expand Down Expand Up @@ -822,13 +812,9 @@ class VarBindingNode : public BindingNode {
v->Visit("span", &span);
}

bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const {
return equal.DefEqual(var, other->var) && equal(value, other->value);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(var);
hash_reduce(value);
}
bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const;
void SHashReduce(SHashReducer hash_reduce) const;

static constexpr const char* _type_key = "relax.expr.VarBinding";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
Expand Down
45 changes: 32 additions & 13 deletions src/node/structural_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/node/structural_equal.h>
#include <tvm/runtime/registry.h>

#include <optional>
#include <unordered_map>

#include "ndarray_hash_equal.h"
Expand Down Expand Up @@ -249,15 +250,30 @@ class SEqualHandlerDefault::Impl {
// in which case we can use same_as for quick checking,
// or we have to run deep comparison and avoid to use same_as checks.
auto run = [=]() {
if (!lhs.defined() && !rhs.defined()) return true;
if (!lhs.defined() && rhs.defined()) return false;
if (!rhs.defined() && lhs.defined()) return false;
if (lhs->type_index() != rhs->type_index()) return false;
auto it = equal_map_lhs_.find(lhs);
if (it != equal_map_lhs_.end()) {
return it->second.same_as(rhs);
std::optional<bool> early_result = [&]() -> std::optional<bool> {
if (!lhs.defined() && !rhs.defined()) return true;
if (!lhs.defined() && rhs.defined()) return false;
if (!rhs.defined() && lhs.defined()) return false;
if (lhs->type_index() != rhs->type_index()) return false;
auto it = equal_map_lhs_.find(lhs);
if (it != equal_map_lhs_.end()) {
return it->second.same_as(rhs);
}
if (equal_map_rhs_.count(rhs)) return false;

return std::nullopt;
}();

if (early_result.has_value()) {
if (early_result.value()) {
return true;
} else if (IsPathTracingEnabled() && IsFailDeferralEnabled() && current_paths.defined()) {
DeferFail(current_paths.value());
return true;
} else {
return false;
}
}
if (equal_map_rhs_.count(rhs)) return false;

// need to push to pending tasks in this case
pending_tasks_.emplace_back(lhs, rhs, map_free_vars, current_paths);
Expand Down Expand Up @@ -388,10 +404,7 @@ class SEqualHandlerDefault::Impl {
auto& entry = task_stack_.back();

if (entry.force_fail) {
if (IsPathTracingEnabled() && !first_mismatch_->defined()) {
*first_mismatch_ = entry.current_paths;
}
return false;
return CheckResult(false, entry.lhs, entry.rhs, entry.current_paths);
}

if (entry.children_expanded) {
Expand Down Expand Up @@ -530,8 +543,14 @@ bool SEqualHandlerDefault::DispatchSEqualReduce(const ObjectRef& lhs, const Obje
TVM_REGISTER_GLOBAL("node.StructuralEqual")
.set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool assert_mode,
bool map_free_vars) {
// If we are asserting on failure, then the `defer_fails` option
// should be enabled, to provide better error messages. For
// example, if the number of bindings in a `relax::BindingBlock`
// differs, highlighting the first difference rather than the
// entire block.
bool defer_fails = assert_mode;
Optional<ObjectPathPair> first_mismatch;
return SEqualHandlerDefault(assert_mode, &first_mismatch, false)
return SEqualHandlerDefault(assert_mode, &first_mismatch, defer_fails)
.Equal(lhs, rhs, map_free_vars);
});

Expand Down
50 changes: 50 additions & 0 deletions src/relax/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,33 @@ TVM_REGISTER_GLOBAL("relax.MatchCast")
return MatchCast(var, value, struct_info, span);
});

bool MatchCastNode::SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const {
if (value->IsInstance<FunctionNode>()) {
// Recursive function definitions may reference the bound variable
// within the value being bound. In these cases, the
// `DefEqual(var, other->var)` must occur first, to ensure it is
// defined at point of use.
return equal.DefEqual(var, other->var) && equal.DefEqual(struct_info, other->struct_info) &&
Comment on lines +389 to +393
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per the above discussion, this sounds correct and indeed, local functions are the only time a local var can be used recursively in Relax.

equal(value, other->value);
} else {
// In all other cases, visit the bound value before the variable
// it is bound to, in order to provide better error messages.
return equal(value, other->value) && equal.DefEqual(struct_info, other->struct_info) &&
equal.DefEqual(var, other->var);
}
}
void MatchCastNode::SHashReduce(SHashReducer hash_reduce) const {
if (value->IsInstance<FunctionNode>()) {
hash_reduce.DefHash(var);
hash_reduce.DefHash(struct_info);
hash_reduce(value);
} else {
hash_reduce(value);
hash_reduce.DefHash(struct_info);
hash_reduce.DefHash(var);
}
}

TVM_REGISTER_NODE_TYPE(VarBindingNode);

VarBinding::VarBinding(Var var, Expr value, Span span) {
Expand All @@ -398,6 +425,29 @@ TVM_REGISTER_GLOBAL("relax.VarBinding").set_body_typed([](Var var, Expr value, S
return VarBinding(var, value, span);
});

bool VarBindingNode::SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const {
if (value->IsInstance<FunctionNode>()) {
// Recursive function definitions may reference the bound variable
// within the value being bound. In these cases, the
// `DefEqual(var, other->var)` must occur first, to ensure it is
// defined at point of use.
return equal.DefEqual(var, other->var) && equal(value, other->value);
} else {
// In all other cases, visit the bound value before the variable
// it is bound to, in order to provide better error messages.
return equal(value, other->value) && equal.DefEqual(var, other->var);
}
}
void VarBindingNode::SHashReduce(SHashReducer hash_reduce) const {
if (value->IsInstance<FunctionNode>()) {
hash_reduce.DefHash(var);
hash_reduce(value);
} else {
hash_reduce(value);
hash_reduce.DefHash(var);
}
}

TVM_REGISTER_NODE_TYPE(BindingBlockNode);

BindingBlock::BindingBlock(Array<Binding> bindings, Span span) {
Expand Down
63 changes: 62 additions & 1 deletion tests/python/relax/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import re

import pytest

import tvm
from tvm import relax
from tvm.ir.base import assert_structural_equal
from tvm.script.parser import relax as R
from tvm.script.parser import relax as R, tir as T


def test_copy_with_new_vars():
Expand Down Expand Up @@ -122,6 +125,27 @@ def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"):
assert_structural_equal(Actual, Expected)


def test_assert_structural_equal_in_seqexpr():
"""The first mismatch is correctly identified."""

@R.function(private=True)
def func_1(A: R.Tensor([16, 16], "float32")):
B = R.concat([A, A])
return B

@R.function(private=True)
def func_2(A: R.Tensor([16, 16], "float32")):
B = R.add(A, A)
C = R.add(B, B)
return B

with pytest.raises(
ValueError,
match=re.escape("<root>.body.blocks[0].bindings[0].value.op"),
):
assert_structural_equal(func_1, func_2)


def test_structural_equal_of_call_nodes():
"""relax.Call must be compared by structural equality, not reference"""

Expand All @@ -145,5 +169,42 @@ def uses_two_different_objects():
tvm.ir.assert_structural_equal(uses_same_object_twice, uses_two_different_objects)


def test_structural_equal_with_recursive_lambda_function():
"""A recursive lambda function may be checked for structural equality

Recursive function definitions may reference the bound variable
within the value being bound. In these cases, the `DefEqual(var,
other->var)` must occur first, to ensure it is defined at point of
use.

In all other cases, checking for structural equality of the bound
value prior to the variable provides a better error message.
"""

def define_function():
@R.function
def func(n: R.Prim("int64")):
@R.function
def recursive_lambda(i_arg: R.Prim(value="i")) -> R.Prim("int64"):
i = T.int64()
if R.prim_value(i == 0):
output = R.prim_value(T.int64(0))
else:
remainder_relax = recursive_lambda(R.prim_value(i - 1))
remainder_tir = T.int64()
_ = R.match_cast(remainder_relax, R.Prim(value=remainder_tir))
output = R.prim_value(i + remainder_tir)
return output

return recursive_lambda(n)

return func

func_1 = define_function()
func_2 = define_function()

tvm.ir.assert_structural_equal(func_1, func_2)


if __name__ == "__main__":
pytest.main([__file__])