diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index fa488a38ce0a..f0cb1bb6af96 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -94,7 +94,7 @@ struct WorkloadHash { /*! \brief The equality check for Workload */ struct WorkloadEqual { bool operator()(const Workload& a, const Workload& b) const { - return a->shash == b->shash && tvm::StructuralEqual()(a->mod, b->mod); + return a->shash == b->shash && tvm::StructuralEqual(false)(a->mod, b->mod); } }; diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index b51021fe4076..7c9081e96baa 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -102,6 +102,13 @@ class ObjectPathPair : public ObjectRef { */ class StructuralEqual : public BaseValueEqual { public: + /*! + * \brief Constructor + * \param compare_ndarray_data Whether or not we compare ndarray data to determine equality. + */ + explicit StructuralEqual(bool compare_ndarray_data = true) + : compare_ndarray_data_(compare_ndarray_data) {} + // inheritate operator() using BaseValueEqual::operator(); /*! @@ -111,6 +118,10 @@ class StructuralEqual : public BaseValueEqual { * \return The comparison result. */ TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const; + + private: + /*! \brief Whether or not we compare ndarray data to determine equality. */ + bool compare_ndarray_data_; }; /*! diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index a30a2c59d0d1..bd106f85cab2 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -74,6 +74,12 @@ class BaseValueHash { */ class StructuralHash : public BaseValueHash { public: + /*! + * \brief Constructor + * \param hash_ndarray_data Whether or not we hash ndarray data. + */ + explicit StructuralHash(bool hash_ndarray_data = true) : hash_ndarray_data_(hash_ndarray_data) {} + // inheritate operator() using BaseValueHash::operator(); /*! @@ -82,6 +88,10 @@ class StructuralHash : public BaseValueHash { * \return The hash value. */ TVM_DLL size_t operator()(const ObjectRef& key) const; + + private: + /*! \brief Whether or not we hash ndarray data. */ + bool hash_ndarray_data_; }; /*! diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index c6b30d38edac..b86721f3a992 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -157,7 +157,7 @@ def save_json(node): return tvm.runtime._ffi_node_api.SaveJSON(node) -def structural_equal(lhs, rhs, map_free_vars=False): +def structural_equal(lhs, rhs, map_free_vars=False, compare_ndarray_data=True): """Check structural equality of lhs and rhs. The structural equality is recursively defined in the DAG of IRNodes. @@ -194,6 +194,9 @@ def structural_equal(lhs, rhs, map_free_vars=False): Whether free variables (i.e. variables without a definition site) should be mapped as equal to each other. + compare_ndarray_data : bool + Whether or not we compare ndarray data to determine equality. + Return ------ result : bool @@ -206,7 +209,11 @@ def structural_equal(lhs, rhs, map_free_vars=False): """ lhs = tvm.runtime.convert(lhs) rhs = tvm.runtime.convert(rhs) - return bool(tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, False, map_free_vars)) + return bool( + tvm.runtime._ffi_node_api.StructuralEqual( + lhs, rhs, False, map_free_vars, compare_ndarray_data + ) + ) def get_first_structural_mismatch(lhs, rhs, map_free_vars=False): @@ -239,7 +246,7 @@ def get_first_structural_mismatch(lhs, rhs, map_free_vars=False): return mismatch.lhs_path, mismatch.rhs_path -def assert_structural_equal(lhs, rhs, map_free_vars=False): +def assert_structural_equal(lhs, rhs, map_free_vars=False, compare_ndarray_data=True): """Assert lhs and rhs are structurally equal to each other. Parameters @@ -254,6 +261,9 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False): Whether or not shall we map free vars that does not bound to any definitions as equal to each other. + compare_ndarray_data : bool + Whether or not we compare ndarray data to determine equality. + Raises ------ ValueError : if assertion does not hold. @@ -264,10 +274,10 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False): """ lhs = tvm.runtime.convert(lhs) rhs = tvm.runtime.convert(rhs) - tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars) + tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars, compare_ndarray_data) -def structural_hash(node, map_free_vars=False): +def structural_hash(node, map_free_vars=False, hash_ndarray_data=True): """Compute structural hash of node The structural hash value is recursively defined in the DAG of IRNodes. @@ -297,6 +307,9 @@ def structural_hash(node, map_free_vars=False): by the order of their occurrences. Otherwise, we will hash by their in-memory pointer address. + hash_ndarray_data : bool + Whether or not we hash ndarray data. + Return ------ result : int @@ -306,4 +319,4 @@ def structural_hash(node, map_free_vars=False): -------- structrual_equal """ - return tvm.runtime._ffi_node_api.StructuralHash(node, map_free_vars) + return tvm.runtime._ffi_node_api.StructuralHash(node, map_free_vars, hash_ndarray_data) diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index f5832fcb3ab8..4be818e15ae6 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -104,7 +104,7 @@ external! { #[name("ir.DebugPrint")] pub fn debug_print(object: ObjectRef) -> CString; #[name("node.StructuralHash")] - fn structural_hash(object: ObjectRef, map_free_vars: bool) -> i64; + fn structural_hash(object: ObjectRef, map_free_vars: bool, hash_ndarray_data : bool) -> i64; #[name("node.StructuralEqual")] - fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> bool; + fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool, compare_ndarray_data : bool) -> bool; } diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc index d082ff7a3901..d79caf692269 100644 --- a/src/meta_schedule/database/database.cc +++ b/src/meta_schedule/database/database.cc @@ -25,7 +25,7 @@ namespace meta_schedule { Workload::Workload(IRModule mod) { ObjectPtr n = runtime::make_object(); - n->shash = tvm::StructuralHash()(mod); + n->shash = tvm::StructuralHash(/*hash_ndarray_data*/ false)(mod); n->mod = mod; data_ = std::move(n); } @@ -61,7 +61,7 @@ Workload Workload::FromJSON(const ObjectRef& json_obj) { mod = Downcast(LoadJSON(json_mod)); } // Verify SHash(mod) == shash - shash = tvm::StructuralHash()(mod); + shash = tvm::StructuralHash(/*hash_ndarray_data*/ false)(mod); String recalc_shash = SHash2Str(shash); CHECK_EQ(recalc_shash, str_shash) << "ValueError: Structural hash changed. Given: " << str_shash << "; Recalculated: " << recalc_shash; diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index 91b96c82479f..0262cfcf715a 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -88,13 +88,14 @@ class JSONDatabaseNode : public DatabaseNode { public: bool HasWorkload(const IRModule& mod) { - return workloads2idx_.find(Workload(mod, tvm::StructuralHash()(mod))) != workloads2idx_.end(); + return workloads2idx_.find(Workload( + mod, tvm::StructuralHash(/*hash_ndarray_data*/ false)(mod))) != workloads2idx_.end(); } Workload CommitWorkload(const IRModule& mod) { // Try to insert `mod` into `workloads_` - auto [it, inserted] = - this->workloads2idx_.emplace(Workload(mod, tvm::StructuralHash()(mod)), -1); + auto [it, inserted] = this->workloads2idx_.emplace( + Workload(mod, tvm::StructuralHash(/*hash_ndarray_data*/ false)(mod)), -1); Workload workload = it->first; // If `mod` is new in `workloads2idx_`, append it to the workload file if (inserted) { diff --git a/src/meta_schedule/database/memory_database.cc b/src/meta_schedule/database/memory_database.cc index b6c635555152..043e9a0c09a1 100644 --- a/src/meta_schedule/database/memory_database.cc +++ b/src/meta_schedule/database/memory_database.cc @@ -37,7 +37,7 @@ class MemoryDatabaseNode : public DatabaseNode { public: bool HasWorkload(const IRModule& mod) final { for (const auto& workload : workloads) { - if (StructuralEqual()(workload->mod, mod)) { + if (StructuralEqual(/*compare_ndarray_data*/ false)(workload->mod, mod)) { return true; } } @@ -46,7 +46,7 @@ class MemoryDatabaseNode : public DatabaseNode { Workload CommitWorkload(const IRModule& mod) final { for (const auto& workload : workloads) { - if (StructuralEqual()(workload->mod, mod)) { + if (StructuralEqual(/*compare_ndarray_data*/ false)(workload->mod, mod)) { return workload; } } diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 01874c0536ae..32a9dab2a108 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -189,6 +189,39 @@ bool SEqualReducer::ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, } } +bool NDArrayEqual(const runtime::NDArray::Container* lhs, const runtime::NDArray::Container* rhs, + SEqualReducer equal, bool compare_data) { + if (lhs == rhs) return true; + + auto ldt = lhs->dl_tensor.dtype; + auto rdt = rhs->dl_tensor.dtype; + ICHECK_EQ(lhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; + ICHECK_EQ(rhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; + ICHECK(runtime::IsContiguous(lhs->dl_tensor)) << "Can only compare contiguous tensor"; + ICHECK(runtime::IsContiguous(rhs->dl_tensor)) << "Can only compare contiguous tensor"; + + if (lhs->dl_tensor.ndim != rhs->dl_tensor.ndim) return false; + for (int i = 0; i < lhs->dl_tensor.ndim; ++i) { + if (!equal(lhs->dl_tensor.shape[i], rhs->dl_tensor.shape[i])) return false; + } + if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) { + size_t data_size = runtime::GetDataSize(lhs->dl_tensor); + if (compare_data) { + return std::memcmp(lhs->dl_tensor.data, rhs->dl_tensor.data, data_size) == 0; + } else { + return true; + } + } else { + return false; + } +} + +bool NDArrayContainerTrait::SEqualReduce(const runtime::NDArray::Container* lhs, + const runtime::NDArray::Container* rhs, + SEqualReducer equal) { + return NDArrayEqual(lhs, rhs, equal, true); +} + /*! * \brief A non recursive stack based SEqual handler that can remaps vars. * @@ -200,8 +233,11 @@ bool SEqualReducer::ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, */ class RemapVarSEqualHandler : public SEqualReducer::Handler { public: - explicit RemapVarSEqualHandler(bool assert_mode, Optional* first_mismatch) - : assert_mode_(assert_mode), first_mismatch_(first_mismatch) {} + explicit RemapVarSEqualHandler(bool assert_mode, Optional* first_mismatch, + bool compare_ndarray_data = true) + : assert_mode_(assert_mode), + first_mismatch_(first_mismatch), + compare_ndarray_data_(compare_ndarray_data) {} bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, const Optional& current_paths) final { @@ -362,19 +398,30 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler { if (equal_map_rhs_.count(rhs)) return false; // Run reduce check for free nodes. - if (!IsPathTracingEnabled()) { - return vtable_->SEqualReduce(lhs.get(), rhs.get(), - SEqualReducer(this, nullptr, map_free_vars)); + SEqualReducer reducer = GetReducer(lhs, rhs, map_free_vars, current_paths); + + if (auto lhs_ptr = lhs.as(), + rhs_ptr = rhs.as(); + lhs_ptr && rhs_ptr) { + return NDArrayEqual(lhs_ptr, rhs_ptr, reducer, compare_ndarray_data_); } else { - PathTracingData tracing_data = {current_paths.value(), lhs, rhs, first_mismatch_}; - return vtable_->SEqualReduce(lhs.get(), rhs.get(), - SEqualReducer(this, &tracing_data, map_free_vars)); + return vtable_->SEqualReduce(lhs.get(), rhs.get(), reducer); } }; return CheckResult(compute(), lhs, rhs, current_paths); } private: + SEqualReducer GetReducer(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, + const Optional& current_paths) { + if (!IsPathTracingEnabled()) { + return SEqualReducer(this, nullptr, map_free_vars); + } else { + PathTracingData tracing_data = {current_paths.value(), lhs, rhs, first_mismatch_}; + return SEqualReducer(this, &tracing_data, map_free_vars); + } + } + /*! \brief Pending reduce tasks. */ struct Task { /*! \brief The lhs operand to be compared. */ @@ -423,12 +470,15 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler { std::unordered_map equal_map_lhs_; // map from rhs to lhs std::unordered_map equal_map_rhs_; + // Whether or not compare ndarray raw data + bool compare_ndarray_data_; }; TVM_REGISTER_GLOBAL("node.StructuralEqual") .set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool assert_mode, - bool map_free_vars) { - return RemapVarSEqualHandler(assert_mode, nullptr).Equal(lhs, rhs, map_free_vars); + bool map_free_vars, bool compare_ndarray_data) { + return RemapVarSEqualHandler(assert_mode, nullptr, compare_ndarray_data) + .Equal(lhs, rhs, map_free_vars); }); TVM_REGISTER_GLOBAL("node.GetFirstStructuralMismatch") @@ -440,7 +490,7 @@ TVM_REGISTER_GLOBAL("node.GetFirstStructuralMismatch") }); bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const { - return RemapVarSEqualHandler(false, nullptr).Equal(lhs, rhs, false); + return RemapVarSEqualHandler(false, nullptr, compare_ndarray_data_).Equal(lhs, rhs, false); } } // namespace tvm diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index b40b1751fb78..86e8e96a87a5 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -48,6 +48,22 @@ void ReflectionVTable::SHashReduce(const Object* self, SHashReducer reducer) con fshash_reduce_[tindex](self, reducer); } +void NDArrayHash(const runtime::NDArray::Container* arr, SHashReducer* hash_reduce, + bool hash_data) { + ICHECK_EQ(arr->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; + ICHECK(runtime::IsContiguous(arr->dl_tensor)) << "Can only hash contiguous tensor"; + (*hash_reduce)(runtime::DataType(arr->dl_tensor.dtype)); + (*hash_reduce)(arr->dl_tensor.ndim); + for (int i = 0; i < arr->dl_tensor.ndim; ++i) { + (*hash_reduce)(arr->dl_tensor.shape[i]); + } + if (hash_data) { + (*hash_reduce) + ->SHashReduceHashedValue(runtime::String::HashBytes( + static_cast(arr->dl_tensor.data), runtime::GetDataSize(arr->dl_tensor))); + } +} + // Hash handler that handles free vars // by assigning an unique counter in the order of their ocurrence. // @@ -57,6 +73,9 @@ void ReflectionVTable::SHashReduce(const Object* self, SHashReducer reducer) con class VarCountingSHashHandler : public SHashReducer::Handler { public: + explicit VarCountingSHashHandler(bool hash_ndarray_data = true) + : hash_ndarray_data_(hash_ndarray_data) {} + /*! \brief Pending reduce tasks. */ struct Task { /*! @@ -234,7 +253,12 @@ class VarCountingSHashHandler : public SHashReducer::Handler { // The default equal as registered in the structural equal vtable. void DispatchSHash(const ObjectRef& object, bool map_free_vars) { ICHECK(object.defined()); - vtable_->SHashReduce(object.get(), SHashReducer(this, map_free_vars)); + SHashReducer hash_reduce(this, map_free_vars); + if (auto ndarray = object.as()) { + NDArrayHash(ndarray, &hash_reduce, hash_ndarray_data_); + } else { + vtable_->SHashReduce(object.get(), hash_reduce); + } } private: @@ -254,16 +278,19 @@ class VarCountingSHashHandler : public SHashReducer::Handler { ReflectionVTable* vtable_ = ReflectionVTable::Global(); // map from lhs to rhs std::unordered_map hash_memo_; + // Whether or not hash ndarray raw data + bool hash_ndarray_data_; }; TVM_REGISTER_GLOBAL("node.StructuralHash") - .set_body_typed([](const ObjectRef& object, bool map_free_vars) -> int64_t { - size_t hashed_value = VarCountingSHashHandler().Hash(object, map_free_vars); + .set_body_typed([](const ObjectRef& object, bool map_free_vars, + bool hash_ndarray_data) -> int64_t { + size_t hashed_value = VarCountingSHashHandler(hash_ndarray_data).Hash(object, map_free_vars); return static_cast(hashed_value); }); size_t StructuralHash::operator()(const ObjectRef& object) const { - return VarCountingSHashHandler().Hash(object, false); + return VarCountingSHashHandler(hash_ndarray_data_).Hash(object, false); } // SEQualReduce traits for runtime containers. @@ -331,39 +358,7 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait); void NDArrayContainerTrait::SHashReduce(const runtime::NDArray::Container* key, SHashReducer hash_reduce) { - ICHECK_EQ(key->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; - ICHECK(runtime::IsContiguous(key->dl_tensor)) << "Can only hash contiguous tensor"; - hash_reduce(runtime::DataType(key->dl_tensor.dtype)); - hash_reduce(key->dl_tensor.ndim); - for (int i = 0; i < key->dl_tensor.ndim; ++i) { - hash_reduce(key->dl_tensor.shape[i]); - } - hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes( - static_cast(key->dl_tensor.data), runtime::GetDataSize(key->dl_tensor))); -} - -bool NDArrayContainerTrait::SEqualReduce(const runtime::NDArray::Container* lhs, - const runtime::NDArray::Container* rhs, - SEqualReducer equal) { - if (lhs == rhs) return true; - - auto ldt = lhs->dl_tensor.dtype; - auto rdt = rhs->dl_tensor.dtype; - ICHECK_EQ(lhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; - ICHECK_EQ(rhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; - ICHECK(runtime::IsContiguous(lhs->dl_tensor)) << "Can only compare contiguous tensor"; - ICHECK(runtime::IsContiguous(rhs->dl_tensor)) << "Can only compare contiguous tensor"; - - if (lhs->dl_tensor.ndim != rhs->dl_tensor.ndim) return false; - for (int i = 0; i < lhs->dl_tensor.ndim; ++i) { - if (!equal(lhs->dl_tensor.shape[i], rhs->dl_tensor.shape[i])) return false; - } - if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) { - size_t data_size = runtime::GetDataSize(lhs->dl_tensor); - return std::memcmp(lhs->dl_tensor.data, rhs->dl_tensor.data, data_size) == 0; - } else { - return false; - } + NDArrayHash(key, &hash_reduce, /*hash_data*/ true); } TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait) diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index 213841c621de..74577c973eb0 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -49,7 +49,7 @@ Array ExtractTask(IRModule mod, Target target, if (!relay_func->HasNonzeroAttr(attr::kPrimitive)) { return; } - tec::CCacheKey cache_key(relay_func, target); + tec::CCacheKey cache_key(relay_func, target, /*ignore_ndarray_data*/ true); auto it = cache.find(cache_key); if (it != cache.end()) { it->second->weight += 1; diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 8fa8610c0fca..a7efe969e11b 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -918,7 +918,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { } else { // Cases 1 and 2: lower the primitive function for the desired target, possibly using external // codegen. - CCacheKey key(Downcast(primitive_func), target, + CCacheKey key(Downcast(primitive_func), target, /*ignore_ndarray_data*/ false, GetVirtualDevice(GetRef(call_node))); CachedFunc cfunc = compiler_->Lower(key); ICHECK(cfunc.defined()); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 1d7566ebe2bd..496e26102fd0 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -69,10 +69,11 @@ LoweredOutput::LoweredOutput(tvm::Array outputs, OpImplementation im data_ = std::move(n); } -CCacheKey::CCacheKey(Function source_func, Target target, VirtualDevice vd) { +CCacheKey::CCacheKey(Function source_func, Target target, bool ignore_ndarray_data, VirtualDevice vd) { auto n = make_object(); n->source_func = std::move(source_func); n->target = std::move(target); + n->ignore_ndarray_data = ignore_ndarray_data; n->virtual_device = std::move(vd); data_ = std::move(n); } @@ -365,8 +366,9 @@ class ScheduleBuilder : public ExprVisitor { constants.push_back(const_node->data); } if (Optional f = tir_converter(te_args, constants)) { + IRModule query_mod = backend::PrimFuncToIRModule(f.value()); if (Optional opt_record = database_.value()->QueryTuningRecord( - /*mod=*/backend::PrimFuncToIRModule(f.value()), + /*mod=*/query_mod, /*target=*/target_, /*workload_name=*/prim_fn_var->name_hint)) { static InstructionKind kind_transform_layout = InstructionKind::Get("TransformLayout"); @@ -377,7 +379,7 @@ class ScheduleBuilder : public ExprVisitor { MetaScheduleLayoutRewriter::LayoutQueuePush(Downcast(inst->attrs[2])); } } - Schedule sch = Schedule::Traced(record->workload->mod, /*seed=*/-1, /*debug_mask=*/0, + Schedule sch = Schedule::Traced(query_mod, /*seed=*/-1, /*debug_mask=*/0, tir::ScheduleErrorRenderLevel::kDetail); record->trace->ApplyToSchedule(sch, /*remove_postproc=*/false); IRModule mod = sch->mod(); diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h index 95c5bc974181..83976341010d 100644 --- a/src/relay/backend/te_compiler_cache.h +++ b/src/relay/backend/te_compiler_cache.h @@ -86,10 +86,13 @@ class CCacheKeyNode : public Object { Target target; /*! \brief The virtual device constrains.*/ VirtualDevice virtual_device; + /*! \brief Whether or not we ignore ndarray raw data when comparing and hashing them. */ + bool ignore_ndarray_data; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("source_func", &source_func); v->Visit("target", &target); + v->Visit("ignore_ndarray_data", &ignore_ndarray_data); v->Visit("virtual_device", &virtual_device); } /*! \return The hash value of CCacheKey. */ @@ -121,8 +124,10 @@ class CCacheKey : public ObjectRef { * \brief The constructor * \param source_func The source function. * \param target The target device. + * \param ignore_ndarray_data Whether or not we ignore ndarray raw data when comparing and hashing + * them */ - TVM_DLL CCacheKey(Function source_func, Target target, + TVM_DLL CCacheKey(Function source_func, Target target, bool ignore_ndarray_data = false, VirtualDevice virtual_device = VirtualDevice::FullyUnconstrained()); const CCacheKeyNode* operator->() const { return static_cast(get()); } @@ -238,7 +243,7 @@ CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, inline size_t CCacheKeyNode::Hash() const { if (hash_ != 0) return hash_; // do structral hash, avoid 0. - hash_ = tvm::StructuralHash()(this->source_func); + hash_ = tvm::StructuralHash(!ignore_ndarray_data)(this->source_func); hash_ = dmlc::HashCombine(hash_, std::hash()(target->str())); if (hash_ == 0) hash_ = 1; return hash_; @@ -248,7 +253,7 @@ inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const { if (Hash() != other->Hash()) return false; return this->target->str() == other->target->str() && this->virtual_device == other->virtual_device && - tvm::StructuralEqual()(this->source_func, other->source_func); + tvm::StructuralEqual(!ignore_ndarray_data)(this->source_func, other->source_func); } } // namespace tec diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 42fec9e27af2..1f37e52daa9d 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -111,7 +111,7 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) { // Compare the objects for structural equality static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual"); ICHECK(structural_equal) << "node.StructuralEqual is not registered."; - if ((*structural_equal)(lhs, GetRef(rhs.ptr()), false, true)) { + if ((*structural_equal)(lhs, GetRef(rhs.ptr()), false, true, true)) { return true; } } @@ -815,7 +815,7 @@ Expr PatternRewriter::Rewrite(const Array& callbacks, const E VLOG(1) << "post rewritten:" << std::endl << PrettyPrint(post); count++; } - equal = (*structural_equal)(last, post, false, true); + equal = (*structural_equal)(last, post, false, true, true); } while (!equal && count < 100 && !callback_->rewrite_once); if (count >= 100) { LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting passes?"; diff --git a/tests/python/relay/test_ir_structural_equal_hash.py b/tests/python/relay/test_ir_structural_equal_hash.py index a808259d26af..3fe85f5bcd0b 100644 --- a/tests/python/relay/test_ir_structural_equal_hash.py +++ b/tests/python/relay/test_ir_structural_equal_hash.py @@ -796,5 +796,54 @@ def func3(): assert not tvm.ir.structural_equal(func1(), func3()) +def test_ignore_ndarray_data(): + a = relay.const([1.0, 2.0], dtype="float32") + b = relay.const([2.0, 1.0], dtype="float32") + assert not tvm.ir.structural_equal(a, b) + assert not tvm.ir.structural_hash(a) == tvm.ir.structural_hash(b) + assert tvm.ir.structural_equal(a, b, compare_ndarray_data=False) + assert tvm.ir.structural_hash(a, hash_ndarray_data=False) == tvm.ir.structural_hash( + b, hash_ndarray_data=False + ) + + def get_conv2d_nchw( + d_shape, + w_shape, + ): + data = relay.var("data", shape=d_shape, dtype="float16") + weight = relay.var("weight", shape=w_shape, dtype="float16") + out_channel = w_shape[0] + return relay.nn.conv2d( + data=data, + weight=weight, + kernel_size=w_shape[2:], + channels=out_channel, + padding=(1, 1), + strides=(1, 1), + out_dtype="float16", + ) + + data_shape = (1, 128, 28, 28) + weight_shape = (128, 128, 1, 1) + + conv2d = get_conv2d_nchw(data_shape, weight_shape) + mod1 = tvm.IRModule.from_expr(conv2d) + mod2 = tvm.IRModule.from_expr(conv2d) + + params = {"weight": np.random.randn(*weight_shape).astype("float16")} + mod1 = relay.build_module.bind_params_by_name(mod1["main"], params) + + params = {"weight": np.random.randn(*weight_shape).astype("float16")} + mod2 = relay.build_module.bind_params_by_name(mod2["main"], params) + + assert not tvm.ir.structural_equal(mod1, mod2) + assert not tvm.ir.structural_hash(mod1) == tvm.ir.structural_hash(mod2) + + assert tvm.ir.structural_equal(mod1, mod2, compare_ndarray_data=False) + assert tvm.ir.structural_hash(mod1, hash_ndarray_data=False) == tvm.ir.structural_hash( + mod2, hash_ndarray_data=False + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py index 366a2e4887ed..dc8ff21babc2 100644 --- a/tests/python/unittest/test_meta_schedule_integration.py +++ b/tests/python/unittest/test_meta_schedule_integration.py @@ -339,5 +339,23 @@ def test_extract_task_arm_conv2d_nchwc(): assert list(out_type.shape) == [1, 8, 130, 130, 4] +@requires_torch +def test_link_params(): + target = "llvm --num-cores=10" + mod, params, _ = get_network(name="resnet_50", input_shape=[1, 3, 224, 224]) + + pass_config = { + "relay.FuseOps.link_params": True, + "relay.backend.use_meta_schedule": True, + "relay.backend.tir_converter": "default", + } + + extracted_tasks = ms.extract_task_from_relay(mod, target, params, pass_config=pass_config) + + conv2d_tasks = list(filter(lambda task: "conv2d" in task.task_name, extracted_tasks)) + + assert len(conv2d_tasks) == 24 + + if __name__ == "__main__": tvm.testing.main()