Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/meta_schedule/database.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ struct WorkloadHash {
/*! \brief The equality check for Workload */
struct WorkloadEqual {
bool operator()(const Workload& a, const Workload& b) const {
return a->shash == b->shash && tvm::StructuralEqual()(a->mod, b->mod);
return a->shash == b->shash && tvm::StructuralEqual(false)(a->mod, b->mod);
}
};

Expand Down
11 changes: 11 additions & 0 deletions include/tvm/node/structural_equal.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@ class ObjectPathPair : public ObjectRef {
*/
class StructuralEqual : public BaseValueEqual {
public:
/*!
* \brief Constructor
* \param compare_ndarray_data Whether or not we compare ndarray data to determine equality.
*/
explicit StructuralEqual(bool compare_ndarray_data = true)
: compare_ndarray_data_(compare_ndarray_data) {}

// inheritate operator()
using BaseValueEqual::operator();
/*!
Expand All @@ -111,6 +118,10 @@ class StructuralEqual : public BaseValueEqual {
* \return The comparison result.
*/
TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;

private:
/*! \brief Whether or not we compare ndarray data to determine equality. */
bool compare_ndarray_data_;
};

/*!
Expand Down
10 changes: 10 additions & 0 deletions include/tvm/node/structural_hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ class BaseValueHash {
*/
class StructuralHash : public BaseValueHash {
public:
/*!
* \brief Constructor
* \param hash_ndarray_data Whether or not we hash ndarray data.
*/
explicit StructuralHash(bool hash_ndarray_data = true) : hash_ndarray_data_(hash_ndarray_data) {}

// inheritate operator()
using BaseValueHash::operator();
/*!
Expand All @@ -82,6 +88,10 @@ class StructuralHash : public BaseValueHash {
* \return The hash value.
*/
TVM_DLL size_t operator()(const ObjectRef& key) const;

private:
/*! \brief Whether or not we hash ndarray data. */
bool hash_ndarray_data_;
};

/*!
Expand Down
25 changes: 19 additions & 6 deletions python/tvm/ir/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def save_json(node):
return tvm.runtime._ffi_node_api.SaveJSON(node)


def structural_equal(lhs, rhs, map_free_vars=False):
def structural_equal(lhs, rhs, map_free_vars=False, compare_ndarray_data=True):
"""Check structural equality of lhs and rhs.

The structural equality is recursively defined in the DAG of IRNodes.
Expand Down Expand Up @@ -194,6 +194,9 @@ def structural_equal(lhs, rhs, map_free_vars=False):
Whether free variables (i.e. variables without a definition site) should be mapped
as equal to each other.

compare_ndarray_data : bool
Whether or not we compare ndarray data to determine equality.

Return
------
result : bool
Expand All @@ -206,7 +209,11 @@ def structural_equal(lhs, rhs, map_free_vars=False):
"""
lhs = tvm.runtime.convert(lhs)
rhs = tvm.runtime.convert(rhs)
return bool(tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, False, map_free_vars))
return bool(
tvm.runtime._ffi_node_api.StructuralEqual(
lhs, rhs, False, map_free_vars, compare_ndarray_data
)
)


def get_first_structural_mismatch(lhs, rhs, map_free_vars=False):
Expand Down Expand Up @@ -239,7 +246,7 @@ def get_first_structural_mismatch(lhs, rhs, map_free_vars=False):
return mismatch.lhs_path, mismatch.rhs_path


def assert_structural_equal(lhs, rhs, map_free_vars=False):
def assert_structural_equal(lhs, rhs, map_free_vars=False, compare_ndarray_data=True):
"""Assert lhs and rhs are structurally equal to each other.

Parameters
Expand All @@ -254,6 +261,9 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False):
Whether or not shall we map free vars that does
not bound to any definitions as equal to each other.

compare_ndarray_data : bool
Whether or not we compare ndarray data to determine equality.

Raises
------
ValueError : if assertion does not hold.
Expand All @@ -264,10 +274,10 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False):
"""
lhs = tvm.runtime.convert(lhs)
rhs = tvm.runtime.convert(rhs)
tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars)
tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars, compare_ndarray_data)


def structural_hash(node, map_free_vars=False):
def structural_hash(node, map_free_vars=False, hash_ndarray_data=True):
"""Compute structural hash of node

The structural hash value is recursively defined in the DAG of IRNodes.
Expand Down Expand Up @@ -297,6 +307,9 @@ def structural_hash(node, map_free_vars=False):
by the order of their occurrences. Otherwise, we will hash by
their in-memory pointer address.

hash_ndarray_data : bool
Whether or not we hash ndarray data.

Return
------
result : int
Expand All @@ -306,4 +319,4 @@ def structural_hash(node, map_free_vars=False):
--------
structrual_equal
"""
return tvm.runtime._ffi_node_api.StructuralHash(node, map_free_vars)
return tvm.runtime._ffi_node_api.StructuralHash(node, map_free_vars, hash_ndarray_data)
4 changes: 2 additions & 2 deletions rust/tvm-rt/src/object/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ external! {
#[name("ir.DebugPrint")]
pub fn debug_print(object: ObjectRef) -> CString;
#[name("node.StructuralHash")]
fn structural_hash(object: ObjectRef, map_free_vars: bool) -> i64;
fn structural_hash(object: ObjectRef, map_free_vars: bool, hash_ndarray_data : bool) -> i64;
#[name("node.StructuralEqual")]
fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> bool;
fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool, compare_ndarray_data : bool) -> bool;
}
4 changes: 2 additions & 2 deletions src/meta_schedule/database/database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace meta_schedule {

Workload::Workload(IRModule mod) {
ObjectPtr<WorkloadNode> n = runtime::make_object<WorkloadNode>();
n->shash = tvm::StructuralHash()(mod);
n->shash = tvm::StructuralHash(/*hash_ndarray_data*/ false)(mod);
n->mod = mod;
data_ = std::move(n);
}
Expand Down Expand Up @@ -61,7 +61,7 @@ Workload Workload::FromJSON(const ObjectRef& json_obj) {
mod = Downcast<IRModule>(LoadJSON(json_mod));
}
// Verify SHash(mod) == shash
shash = tvm::StructuralHash()(mod);
shash = tvm::StructuralHash(/*hash_ndarray_data*/ false)(mod);
String recalc_shash = SHash2Str(shash);
CHECK_EQ(recalc_shash, str_shash) << "ValueError: Structural hash changed. Given: " << str_shash
<< "; Recalculated: " << recalc_shash;
Expand Down
7 changes: 4 additions & 3 deletions src/meta_schedule/database/json_database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,14 @@ class JSONDatabaseNode : public DatabaseNode {

public:
bool HasWorkload(const IRModule& mod) {
return workloads2idx_.find(Workload(mod, tvm::StructuralHash()(mod))) != workloads2idx_.end();
return workloads2idx_.find(Workload(
mod, tvm::StructuralHash(/*hash_ndarray_data*/ false)(mod))) != workloads2idx_.end();
}

Workload CommitWorkload(const IRModule& mod) {
// Try to insert `mod` into `workloads_`
auto [it, inserted] =
this->workloads2idx_.emplace(Workload(mod, tvm::StructuralHash()(mod)), -1);
auto [it, inserted] = this->workloads2idx_.emplace(
Workload(mod, tvm::StructuralHash(/*hash_ndarray_data*/ false)(mod)), -1);
Workload workload = it->first;
// If `mod` is new in `workloads2idx_`, append it to the workload file
if (inserted) {
Expand Down
4 changes: 2 additions & 2 deletions src/meta_schedule/database/memory_database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class MemoryDatabaseNode : public DatabaseNode {
public:
bool HasWorkload(const IRModule& mod) final {
for (const auto& workload : workloads) {
if (StructuralEqual()(workload->mod, mod)) {
if (StructuralEqual(/*compare_ndarray_data*/ false)(workload->mod, mod)) {
return true;
}
}
Expand All @@ -46,7 +46,7 @@ class MemoryDatabaseNode : public DatabaseNode {

Workload CommitWorkload(const IRModule& mod) final {
for (const auto& workload : workloads) {
if (StructuralEqual()(workload->mod, mod)) {
if (StructuralEqual(/*compare_ndarray_data*/ false)(workload->mod, mod)) {
return workload;
}
}
Expand Down
72 changes: 61 additions & 11 deletions src/node/structural_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,39 @@ bool SEqualReducer::ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs,
}
}

bool NDArrayEqual(const runtime::NDArray::Container* lhs, const runtime::NDArray::Container* rhs,
SEqualReducer equal, bool compare_data) {
if (lhs == rhs) return true;

auto ldt = lhs->dl_tensor.dtype;
auto rdt = rhs->dl_tensor.dtype;
ICHECK_EQ(lhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor";
ICHECK_EQ(rhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor";
ICHECK(runtime::IsContiguous(lhs->dl_tensor)) << "Can only compare contiguous tensor";
ICHECK(runtime::IsContiguous(rhs->dl_tensor)) << "Can only compare contiguous tensor";

if (lhs->dl_tensor.ndim != rhs->dl_tensor.ndim) return false;
for (int i = 0; i < lhs->dl_tensor.ndim; ++i) {
if (!equal(lhs->dl_tensor.shape[i], rhs->dl_tensor.shape[i])) return false;
}
if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) {
size_t data_size = runtime::GetDataSize(lhs->dl_tensor);
if (compare_data) {
return std::memcmp(lhs->dl_tensor.data, rhs->dl_tensor.data, data_size) == 0;
} else {
return true;
}
} else {
return false;
}
}

bool NDArrayContainerTrait::SEqualReduce(const runtime::NDArray::Container* lhs,
const runtime::NDArray::Container* rhs,
SEqualReducer equal) {
return NDArrayEqual(lhs, rhs, equal, true);
}

/*!
* \brief A non recursive stack based SEqual handler that can remaps vars.
*
Expand All @@ -200,8 +233,11 @@ bool SEqualReducer::ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs,
*/
class RemapVarSEqualHandler : public SEqualReducer::Handler {
public:
explicit RemapVarSEqualHandler(bool assert_mode, Optional<ObjectPathPair>* first_mismatch)
: assert_mode_(assert_mode), first_mismatch_(first_mismatch) {}
explicit RemapVarSEqualHandler(bool assert_mode, Optional<ObjectPathPair>* first_mismatch,
bool compare_ndarray_data = true)
: assert_mode_(assert_mode),
first_mismatch_(first_mismatch),
compare_ndarray_data_(compare_ndarray_data) {}

bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
const Optional<ObjectPathPair>& current_paths) final {
Expand Down Expand Up @@ -362,19 +398,30 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
if (equal_map_rhs_.count(rhs)) return false;

// Run reduce check for free nodes.
if (!IsPathTracingEnabled()) {
return vtable_->SEqualReduce(lhs.get(), rhs.get(),
SEqualReducer(this, nullptr, map_free_vars));
SEqualReducer reducer = GetReducer(lhs, rhs, map_free_vars, current_paths);

if (auto lhs_ptr = lhs.as<runtime::NDArray::Container>(),
rhs_ptr = rhs.as<runtime::NDArray::Container>();
lhs_ptr && rhs_ptr) {
return NDArrayEqual(lhs_ptr, rhs_ptr, reducer, compare_ndarray_data_);
} else {
PathTracingData tracing_data = {current_paths.value(), lhs, rhs, first_mismatch_};
return vtable_->SEqualReduce(lhs.get(), rhs.get(),
SEqualReducer(this, &tracing_data, map_free_vars));
return vtable_->SEqualReduce(lhs.get(), rhs.get(), reducer);
}
};
return CheckResult(compute(), lhs, rhs, current_paths);
}

private:
SEqualReducer GetReducer(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
const Optional<ObjectPathPair>& current_paths) {
if (!IsPathTracingEnabled()) {
return SEqualReducer(this, nullptr, map_free_vars);
} else {
PathTracingData tracing_data = {current_paths.value(), lhs, rhs, first_mismatch_};
return SEqualReducer(this, &tracing_data, map_free_vars);
}
}

/*! \brief Pending reduce tasks. */
struct Task {
/*! \brief The lhs operand to be compared. */
Expand Down Expand Up @@ -423,12 +470,15 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_lhs_;
// map from rhs to lhs
std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_rhs_;
// Whether or not compare ndarray raw data
bool compare_ndarray_data_;
};

TVM_REGISTER_GLOBAL("node.StructuralEqual")
.set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool assert_mode,
bool map_free_vars) {
return RemapVarSEqualHandler(assert_mode, nullptr).Equal(lhs, rhs, map_free_vars);
bool map_free_vars, bool compare_ndarray_data) {
return RemapVarSEqualHandler(assert_mode, nullptr, compare_ndarray_data)
.Equal(lhs, rhs, map_free_vars);
});

TVM_REGISTER_GLOBAL("node.GetFirstStructuralMismatch")
Expand All @@ -440,7 +490,7 @@ TVM_REGISTER_GLOBAL("node.GetFirstStructuralMismatch")
});

bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
return RemapVarSEqualHandler(false, nullptr).Equal(lhs, rhs, false);
return RemapVarSEqualHandler(false, nullptr, compare_ndarray_data_).Equal(lhs, rhs, false);
}

} // namespace tvm
Loading