From f4b6aca64467d967e829dd5660ca9effdba99ff4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sat, 3 Sep 2022 12:16:27 +0900 Subject: [PATCH 1/8] [Node] Allow ignoring NDArray raw data in StructuralEqual and StructuralHash --- src/node/structural_equal.cc | 56 ++++++++++++++++++++++++++++++++---- src/node/structural_hash.cc | 56 ++++++++++++++---------------------- 2 files changed, 72 insertions(+), 40 deletions(-) diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 01874c0536ae..03093847cff0 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -189,6 +189,39 @@ bool SEqualReducer::ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, } } +bool NDArrayEqual(const runtime::NDArray::Container* lhs, const runtime::NDArray::Container* rhs, + SEqualReducer equal, bool compare_data) { + if (lhs == rhs) return true; + + auto ldt = lhs->dl_tensor.dtype; + auto rdt = rhs->dl_tensor.dtype; + ICHECK_EQ(lhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; + ICHECK_EQ(rhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; + ICHECK(runtime::IsContiguous(lhs->dl_tensor)) << "Can only compare contiguous tensor"; + ICHECK(runtime::IsContiguous(rhs->dl_tensor)) << "Can only compare contiguous tensor"; + + if (lhs->dl_tensor.ndim != rhs->dl_tensor.ndim) return false; + for (int i = 0; i < lhs->dl_tensor.ndim; ++i) { + if (!equal(lhs->dl_tensor.shape[i], rhs->dl_tensor.shape[i])) return false; + } + if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) { + size_t data_size = runtime::GetDataSize(lhs->dl_tensor); + if (compare_data) { + return std::memcmp(lhs->dl_tensor.data, rhs->dl_tensor.data, data_size) == 0; + } else { + return true; + } + } else { + return false; + } +} + +bool NDArrayContainerTrait::SEqualReduce(const runtime::NDArray::Container* lhs, + const runtime::NDArray::Container* rhs, + SEqualReducer equal) { + return NDArrayEqual(lhs, rhs, equal, true); +} + /*! * \brief A non recursive stack based SEqual handler that can remaps vars. * @@ -362,19 +395,30 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler { if (equal_map_rhs_.count(rhs)) return false; // Run reduce check for free nodes. - if (!IsPathTracingEnabled()) { - return vtable_->SEqualReduce(lhs.get(), rhs.get(), - SEqualReducer(this, nullptr, map_free_vars)); + SEqualReducer reducer = GetReducer(lhs, rhs, map_free_vars, current_paths); + + if (auto lhs_ptr = lhs.as(), + rhs_ptr = rhs.as(); + lhs_ptr && rhs_ptr) { + return NDArrayEqual(lhs_ptr, rhs_ptr, reducer, /*compare_data*/ false); } else { - PathTracingData tracing_data = {current_paths.value(), lhs, rhs, first_mismatch_}; - return vtable_->SEqualReduce(lhs.get(), rhs.get(), - SEqualReducer(this, &tracing_data, map_free_vars)); + return vtable_->SEqualReduce(lhs.get(), rhs.get(), reducer); } }; return CheckResult(compute(), lhs, rhs, current_paths); } private: + SEqualReducer GetReducer(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, + const Optional& current_paths) { + if (!IsPathTracingEnabled()) { + return SEqualReducer(this, nullptr, map_free_vars); + } else { + PathTracingData tracing_data = {current_paths.value(), lhs, rhs, first_mismatch_}; + return SEqualReducer(this, &tracing_data, map_free_vars); + } + } + /*! \brief Pending reduce tasks. */ struct Task { /*! \brief The lhs operand to be compared. */ diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index b40b1751fb78..f6142806603b 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -48,6 +48,21 @@ void ReflectionVTable::SHashReduce(const Object* self, SHashReducer reducer) con fshash_reduce_[tindex](self, reducer); } +void NDArrayHash(const runtime::NDArray::Container* arr, SHashReducer& hash_reduce, + bool hash_data) { + ICHECK_EQ(arr->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; + ICHECK(runtime::IsContiguous(arr->dl_tensor)) << "Can only hash contiguous tensor"; + hash_reduce(runtime::DataType(arr->dl_tensor.dtype)); + hash_reduce(arr->dl_tensor.ndim); + for (int i = 0; i < arr->dl_tensor.ndim; ++i) { + hash_reduce(arr->dl_tensor.shape[i]); + } + if (hash_data) { + hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes( + static_cast(arr->dl_tensor.data), runtime::GetDataSize(arr->dl_tensor))); + } +} + // Hash handler that handles free vars // by assigning an unique counter in the order of their ocurrence. // @@ -234,7 +249,12 @@ class VarCountingSHashHandler : public SHashReducer::Handler { // The default equal as registered in the structural equal vtable. void DispatchSHash(const ObjectRef& object, bool map_free_vars) { ICHECK(object.defined()); - vtable_->SHashReduce(object.get(), SHashReducer(this, map_free_vars)); + SHashReducer hash_reduce(this, map_free_vars); + if (auto ndarray = object.as()) { + NDArrayHash(ndarray, hash_reduce, /*hash_data*/ false); + } else { + vtable_->SHashReduce(object.get(), hash_reduce); + } } private: @@ -331,39 +351,7 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait); void NDArrayContainerTrait::SHashReduce(const runtime::NDArray::Container* key, SHashReducer hash_reduce) { - ICHECK_EQ(key->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; - ICHECK(runtime::IsContiguous(key->dl_tensor)) << "Can only hash contiguous tensor"; - hash_reduce(runtime::DataType(key->dl_tensor.dtype)); - hash_reduce(key->dl_tensor.ndim); - for (int i = 0; i < key->dl_tensor.ndim; ++i) { - hash_reduce(key->dl_tensor.shape[i]); - } - hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes( - static_cast(key->dl_tensor.data), runtime::GetDataSize(key->dl_tensor))); -} - -bool NDArrayContainerTrait::SEqualReduce(const runtime::NDArray::Container* lhs, - const runtime::NDArray::Container* rhs, - SEqualReducer equal) { - if (lhs == rhs) return true; - - auto ldt = lhs->dl_tensor.dtype; - auto rdt = rhs->dl_tensor.dtype; - ICHECK_EQ(lhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; - ICHECK_EQ(rhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; - ICHECK(runtime::IsContiguous(lhs->dl_tensor)) << "Can only compare contiguous tensor"; - ICHECK(runtime::IsContiguous(rhs->dl_tensor)) << "Can only compare contiguous tensor"; - - if (lhs->dl_tensor.ndim != rhs->dl_tensor.ndim) return false; - for (int i = 0; i < lhs->dl_tensor.ndim; ++i) { - if (!equal(lhs->dl_tensor.shape[i], rhs->dl_tensor.shape[i])) return false; - } - if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) { - size_t data_size = runtime::GetDataSize(lhs->dl_tensor); - return std::memcmp(lhs->dl_tensor.data, rhs->dl_tensor.data, data_size) == 0; - } else { - return false; - } + NDArrayHash(key, hash_reduce, /*hash_data*/ true); } TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait) From d16ecbcb9ac8729968b4e8848bc12063eb2087c8 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 5 Sep 2022 17:08:37 +0900 Subject: [PATCH 2/8] add compare_ndarray_data and hash_ndarray_data arguments --- include/tvm/node/structural_equal.h | 5 +++++ include/tvm/node/structural_hash.h | 4 ++++ python/tvm/ir/base.py | 8 ++++---- rust/tvm-rt/src/object/mod.rs | 4 ++-- src/node/structural_equal.cc | 18 ++++++++++++------ src/node/structural_hash.cc | 14 ++++++++++---- src/relay/backend/te_compiler_cache.h | 5 +++-- src/relay/ir/dataflow_matcher.cc | 4 ++-- 8 files changed, 42 insertions(+), 20 deletions(-) diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index b51021fe4076..9afd683f351d 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -102,6 +102,8 @@ class ObjectPathPair : public ObjectRef { */ class StructuralEqual : public BaseValueEqual { public: + StructuralEqual(bool compare_ndarray_data = true) : compare_ndarray_data_(compare_ndarray_data) {} + // inheritate operator() using BaseValueEqual::operator(); /*! @@ -111,6 +113,9 @@ class StructuralEqual : public BaseValueEqual { * \return The comparison result. */ TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const; + + private: + bool compare_ndarray_data_; }; /*! diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index a30a2c59d0d1..c5181468873e 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -74,6 +74,7 @@ class BaseValueHash { */ class StructuralHash : public BaseValueHash { public: + StructuralHash(bool hash_ndarray_data = true) : hash_ndarray_data_(hash_ndarray_data) {} // inheritate operator() using BaseValueHash::operator(); /*! @@ -82,6 +83,9 @@ class StructuralHash : public BaseValueHash { * \return The hash value. */ TVM_DLL size_t operator()(const ObjectRef& key) const; + + private: + bool hash_ndarray_data_; }; /*! diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index c6b30d38edac..6eebef892bc1 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -157,7 +157,7 @@ def save_json(node): return tvm.runtime._ffi_node_api.SaveJSON(node) -def structural_equal(lhs, rhs, map_free_vars=False): +def structural_equal(lhs, rhs, map_free_vars=False, compare_ndarray_data=True): """Check structural equality of lhs and rhs. The structural equality is recursively defined in the DAG of IRNodes. @@ -206,7 +206,7 @@ def structural_equal(lhs, rhs, map_free_vars=False): """ lhs = tvm.runtime.convert(lhs) rhs = tvm.runtime.convert(rhs) - return bool(tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, False, map_free_vars)) + return bool(tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, False, map_free_vars, compare_ndarray_data)) def get_first_structural_mismatch(lhs, rhs, map_free_vars=False): @@ -267,7 +267,7 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False): tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars) -def structural_hash(node, map_free_vars=False): +def structural_hash(node, map_free_vars=False, hash_ndarray_data=True): """Compute structural hash of node The structural hash value is recursively defined in the DAG of IRNodes. @@ -306,4 +306,4 @@ def structural_hash(node, map_free_vars=False): -------- structrual_equal """ - return tvm.runtime._ffi_node_api.StructuralHash(node, map_free_vars) + return tvm.runtime._ffi_node_api.StructuralHash(node, map_free_vars, hash_ndarray_data) diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index f5832fcb3ab8..4be818e15ae6 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -104,7 +104,7 @@ external! { #[name("ir.DebugPrint")] pub fn debug_print(object: ObjectRef) -> CString; #[name("node.StructuralHash")] - fn structural_hash(object: ObjectRef, map_free_vars: bool) -> i64; + fn structural_hash(object: ObjectRef, map_free_vars: bool, hash_ndarray_data : bool) -> i64; #[name("node.StructuralEqual")] - fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool) -> bool; + fn structural_equal(lhs: ObjectRef, rhs: ObjectRef, assert_mode: bool, map_free_vars: bool, compare_ndarray_data : bool) -> bool; } diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 03093847cff0..43d0dbf55975 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -233,8 +233,11 @@ bool NDArrayContainerTrait::SEqualReduce(const runtime::NDArray::Container* lhs, */ class RemapVarSEqualHandler : public SEqualReducer::Handler { public: - explicit RemapVarSEqualHandler(bool assert_mode, Optional* first_mismatch) - : assert_mode_(assert_mode), first_mismatch_(first_mismatch) {} + explicit RemapVarSEqualHandler(bool assert_mode, Optional* first_mismatch, + bool compare_ndarray_data = true) + : assert_mode_(assert_mode), + first_mismatch_(first_mismatch), + compare_ndarray_data_(compare_ndarray_data) {} bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, const Optional& current_paths) final { @@ -400,7 +403,7 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler { if (auto lhs_ptr = lhs.as(), rhs_ptr = rhs.as(); lhs_ptr && rhs_ptr) { - return NDArrayEqual(lhs_ptr, rhs_ptr, reducer, /*compare_data*/ false); + return NDArrayEqual(lhs_ptr, rhs_ptr, reducer, compare_ndarray_data_); } else { return vtable_->SEqualReduce(lhs.get(), rhs.get(), reducer); } @@ -467,12 +470,15 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler { std::unordered_map equal_map_lhs_; // map from rhs to lhs std::unordered_map equal_map_rhs_; + // TODO + bool compare_ndarray_data_; }; TVM_REGISTER_GLOBAL("node.StructuralEqual") .set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool assert_mode, - bool map_free_vars) { - return RemapVarSEqualHandler(assert_mode, nullptr).Equal(lhs, rhs, map_free_vars); + bool map_free_vars, bool compare_ndarray_data) { + return RemapVarSEqualHandler(assert_mode, nullptr, compare_ndarray_data) + .Equal(lhs, rhs, map_free_vars); }); TVM_REGISTER_GLOBAL("node.GetFirstStructuralMismatch") @@ -484,7 +490,7 @@ TVM_REGISTER_GLOBAL("node.GetFirstStructuralMismatch") }); bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const { - return RemapVarSEqualHandler(false, nullptr).Equal(lhs, rhs, false); + return RemapVarSEqualHandler(false, nullptr, compare_ndarray_data_).Equal(lhs, rhs, false); } } // namespace tvm diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index f6142806603b..5126cb7e4fda 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -72,6 +72,9 @@ void NDArrayHash(const runtime::NDArray::Container* arr, SHashReducer& hash_redu class VarCountingSHashHandler : public SHashReducer::Handler { public: + explicit VarCountingSHashHandler(bool hash_ndarray_data = true) + : hash_ndarray_data_(hash_ndarray_data) {} + /*! \brief Pending reduce tasks. */ struct Task { /*! @@ -251,7 +254,7 @@ class VarCountingSHashHandler : public SHashReducer::Handler { ICHECK(object.defined()); SHashReducer hash_reduce(this, map_free_vars); if (auto ndarray = object.as()) { - NDArrayHash(ndarray, hash_reduce, /*hash_data*/ false); + NDArrayHash(ndarray, hash_reduce, hash_ndarray_data_); } else { vtable_->SHashReduce(object.get(), hash_reduce); } @@ -274,16 +277,19 @@ class VarCountingSHashHandler : public SHashReducer::Handler { ReflectionVTable* vtable_ = ReflectionVTable::Global(); // map from lhs to rhs std::unordered_map hash_memo_; + // TODO + bool hash_ndarray_data_; }; TVM_REGISTER_GLOBAL("node.StructuralHash") - .set_body_typed([](const ObjectRef& object, bool map_free_vars) -> int64_t { - size_t hashed_value = VarCountingSHashHandler().Hash(object, map_free_vars); + .set_body_typed([](const ObjectRef& object, bool map_free_vars, + bool hash_ndarray_data) -> int64_t { + size_t hashed_value = VarCountingSHashHandler(hash_ndarray_data).Hash(object, map_free_vars); return static_cast(hashed_value); }); size_t StructuralHash::operator()(const ObjectRef& object) const { - return VarCountingSHashHandler().Hash(object, false); + return VarCountingSHashHandler(hash_ndarray_data_).Hash(object, false); } // SEQualReduce traits for runtime containers. diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h index 95c5bc974181..af8f9fc87831 100644 --- a/src/relay/backend/te_compiler_cache.h +++ b/src/relay/backend/te_compiler_cache.h @@ -238,7 +238,7 @@ CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, inline size_t CCacheKeyNode::Hash() const { if (hash_ != 0) return hash_; // do structral hash, avoid 0. - hash_ = tvm::StructuralHash()(this->source_func); + hash_ = tvm::StructuralHash(/*hash_ndarray_data*/ false)(this->source_func); hash_ = dmlc::HashCombine(hash_, std::hash()(target->str())); if (hash_ == 0) hash_ = 1; return hash_; @@ -248,7 +248,8 @@ inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const { if (Hash() != other->Hash()) return false; return this->target->str() == other->target->str() && this->virtual_device == other->virtual_device && - tvm::StructuralEqual()(this->source_func, other->source_func); + tvm::StructuralEqual(/*compare_ndarray_data*/ false)(this->source_func, + other->source_func); } } // namespace tec diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 42fec9e27af2..1f37e52daa9d 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -111,7 +111,7 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) { // Compare the objects for structural equality static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual"); ICHECK(structural_equal) << "node.StructuralEqual is not registered."; - if ((*structural_equal)(lhs, GetRef(rhs.ptr()), false, true)) { + if ((*structural_equal)(lhs, GetRef(rhs.ptr()), false, true, true)) { return true; } } @@ -815,7 +815,7 @@ Expr PatternRewriter::Rewrite(const Array& callbacks, const E VLOG(1) << "post rewritten:" << std::endl << PrettyPrint(post); count++; } - equal = (*structural_equal)(last, post, false, true); + equal = (*structural_equal)(last, post, false, true, true); } while (!equal && count < 100 && !callback_->rewrite_once); if (count >= 100) { LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting passes?"; From 9adc317d0fe55af9e251bb51610ffe373f3f4fb3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 6 Sep 2022 09:38:42 +0900 Subject: [PATCH 3/8] add test --- python/tvm/ir/base.py | 6 ++- .../relay/test_ir_structural_equal_hash.py | 49 +++++++++++++++++++ .../test_meta_schedule_integration.py | 17 +++++++ 3 files changed, 71 insertions(+), 1 deletion(-) diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index 6eebef892bc1..8eace7762a43 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -206,7 +206,11 @@ def structural_equal(lhs, rhs, map_free_vars=False, compare_ndarray_data=True): """ lhs = tvm.runtime.convert(lhs) rhs = tvm.runtime.convert(rhs) - return bool(tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, False, map_free_vars, compare_ndarray_data)) + return bool( + tvm.runtime._ffi_node_api.StructuralEqual( + lhs, rhs, False, map_free_vars, compare_ndarray_data + ) + ) def get_first_structural_mismatch(lhs, rhs, map_free_vars=False): diff --git a/tests/python/relay/test_ir_structural_equal_hash.py b/tests/python/relay/test_ir_structural_equal_hash.py index a808259d26af..3fe85f5bcd0b 100644 --- a/tests/python/relay/test_ir_structural_equal_hash.py +++ b/tests/python/relay/test_ir_structural_equal_hash.py @@ -796,5 +796,54 @@ def func3(): assert not tvm.ir.structural_equal(func1(), func3()) +def test_ignore_ndarray_data(): + a = relay.const([1.0, 2.0], dtype="float32") + b = relay.const([2.0, 1.0], dtype="float32") + assert not tvm.ir.structural_equal(a, b) + assert not tvm.ir.structural_hash(a) == tvm.ir.structural_hash(b) + assert tvm.ir.structural_equal(a, b, compare_ndarray_data=False) + assert tvm.ir.structural_hash(a, hash_ndarray_data=False) == tvm.ir.structural_hash( + b, hash_ndarray_data=False + ) + + def get_conv2d_nchw( + d_shape, + w_shape, + ): + data = relay.var("data", shape=d_shape, dtype="float16") + weight = relay.var("weight", shape=w_shape, dtype="float16") + out_channel = w_shape[0] + return relay.nn.conv2d( + data=data, + weight=weight, + kernel_size=w_shape[2:], + channels=out_channel, + padding=(1, 1), + strides=(1, 1), + out_dtype="float16", + ) + + data_shape = (1, 128, 28, 28) + weight_shape = (128, 128, 1, 1) + + conv2d = get_conv2d_nchw(data_shape, weight_shape) + mod1 = tvm.IRModule.from_expr(conv2d) + mod2 = tvm.IRModule.from_expr(conv2d) + + params = {"weight": np.random.randn(*weight_shape).astype("float16")} + mod1 = relay.build_module.bind_params_by_name(mod1["main"], params) + + params = {"weight": np.random.randn(*weight_shape).astype("float16")} + mod2 = relay.build_module.bind_params_by_name(mod2["main"], params) + + assert not tvm.ir.structural_equal(mod1, mod2) + assert not tvm.ir.structural_hash(mod1) == tvm.ir.structural_hash(mod2) + + assert tvm.ir.structural_equal(mod1, mod2, compare_ndarray_data=False) + assert tvm.ir.structural_hash(mod1, hash_ndarray_data=False) == tvm.ir.structural_hash( + mod2, hash_ndarray_data=False + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py index 366a2e4887ed..b3726993432f 100644 --- a/tests/python/unittest/test_meta_schedule_integration.py +++ b/tests/python/unittest/test_meta_schedule_integration.py @@ -339,5 +339,22 @@ def test_extract_task_arm_conv2d_nchwc(): assert list(out_type.shape) == [1, 8, 130, 130, 4] +@requires_torch +def test_link_params(): + target = "llvm --num-cores=10" + mod, params, _ = get_network(name="resnet_50", input_shape=[1, 3, 224, 224]) + + pass_config = {"relay.FuseOps.link_params": True, + "relay.backend.use_meta_schedule": True, + "relay.backend.tir_converter": "default" + } + + extracted_tasks = ms.extract_task_from_relay(mod, target, params, pass_config=pass_config) + + conv2d_tasks = list(filter(lambda task: "conv2d" in task.task_name, extracted_tasks)) + + assert len(conv2d_tasks) == 24 + + if __name__ == "__main__": tvm.testing.main() From d307137cff5f680fc51b01e018c50e34772c88ac Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 6 Sep 2022 09:56:07 +0900 Subject: [PATCH 4/8] ignore ndarray data in MS database lookup --- src/meta_schedule/database/database.cc | 4 ++-- src/meta_schedule/database/json_database.cc | 7 ++++--- src/meta_schedule/database/memory_database.cc | 4 ++-- tests/python/unittest/test_meta_schedule_integration.py | 9 +++++---- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc index d082ff7a3901..d79caf692269 100644 --- a/src/meta_schedule/database/database.cc +++ b/src/meta_schedule/database/database.cc @@ -25,7 +25,7 @@ namespace meta_schedule { Workload::Workload(IRModule mod) { ObjectPtr n = runtime::make_object(); - n->shash = tvm::StructuralHash()(mod); + n->shash = tvm::StructuralHash(/*hash_ndarray_data*/ false)(mod); n->mod = mod; data_ = std::move(n); } @@ -61,7 +61,7 @@ Workload Workload::FromJSON(const ObjectRef& json_obj) { mod = Downcast(LoadJSON(json_mod)); } // Verify SHash(mod) == shash - shash = tvm::StructuralHash()(mod); + shash = tvm::StructuralHash(/*hash_ndarray_data*/ false)(mod); String recalc_shash = SHash2Str(shash); CHECK_EQ(recalc_shash, str_shash) << "ValueError: Structural hash changed. Given: " << str_shash << "; Recalculated: " << recalc_shash; diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index 91b96c82479f..0262cfcf715a 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -88,13 +88,14 @@ class JSONDatabaseNode : public DatabaseNode { public: bool HasWorkload(const IRModule& mod) { - return workloads2idx_.find(Workload(mod, tvm::StructuralHash()(mod))) != workloads2idx_.end(); + return workloads2idx_.find(Workload( + mod, tvm::StructuralHash(/*hash_ndarray_data*/ false)(mod))) != workloads2idx_.end(); } Workload CommitWorkload(const IRModule& mod) { // Try to insert `mod` into `workloads_` - auto [it, inserted] = - this->workloads2idx_.emplace(Workload(mod, tvm::StructuralHash()(mod)), -1); + auto [it, inserted] = this->workloads2idx_.emplace( + Workload(mod, tvm::StructuralHash(/*hash_ndarray_data*/ false)(mod)), -1); Workload workload = it->first; // If `mod` is new in `workloads2idx_`, append it to the workload file if (inserted) { diff --git a/src/meta_schedule/database/memory_database.cc b/src/meta_schedule/database/memory_database.cc index b6c635555152..bbab8944dfc9 100644 --- a/src/meta_schedule/database/memory_database.cc +++ b/src/meta_schedule/database/memory_database.cc @@ -37,7 +37,7 @@ class MemoryDatabaseNode : public DatabaseNode { public: bool HasWorkload(const IRModule& mod) final { for (const auto& workload : workloads) { - if (StructuralEqual()(workload->mod, mod)) { + if (StructuralEqual(/*compare_ndarray_data*/ false)(workload->mod, mod)) { return true; } } @@ -46,7 +46,7 @@ class MemoryDatabaseNode : public DatabaseNode { Workload CommitWorkload(const IRModule& mod) final { for (const auto& workload : workloads) { - if (StructuralEqual()(workload->mod, mod)) { + if (StructuralEqual(/*compare_ndarray_data*/false)(workload->mod, mod)) { return workload; } } diff --git a/tests/python/unittest/test_meta_schedule_integration.py b/tests/python/unittest/test_meta_schedule_integration.py index b3726993432f..dc8ff21babc2 100644 --- a/tests/python/unittest/test_meta_schedule_integration.py +++ b/tests/python/unittest/test_meta_schedule_integration.py @@ -344,10 +344,11 @@ def test_link_params(): target = "llvm --num-cores=10" mod, params, _ = get_network(name="resnet_50", input_shape=[1, 3, 224, 224]) - pass_config = {"relay.FuseOps.link_params": True, - "relay.backend.use_meta_schedule": True, - "relay.backend.tir_converter": "default" - } + pass_config = { + "relay.FuseOps.link_params": True, + "relay.backend.use_meta_schedule": True, + "relay.backend.tir_converter": "default", + } extracted_tasks = ms.extract_task_from_relay(mod, target, params, pass_config=pass_config) From 5da43310f42001db73bb2b9db0cfa7b8990a66d7 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 6 Sep 2022 10:13:40 +0900 Subject: [PATCH 5/8] add doc --- include/tvm/node/structural_equal.h | 5 +++++ include/tvm/node/structural_hash.h | 6 ++++++ python/tvm/ir/base.py | 13 +++++++++++-- src/meta_schedule/database/memory_database.cc | 2 +- src/node/structural_equal.cc | 2 +- src/node/structural_hash.cc | 2 +- 6 files changed, 25 insertions(+), 5 deletions(-) diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index 9afd683f351d..74ea9654e1e2 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -102,6 +102,10 @@ class ObjectPathPair : public ObjectRef { */ class StructuralEqual : public BaseValueEqual { public: + /*! + * \brief Constructor + * \param compare_ndarray_data Whether or not we compare ndarray data to determine equality. + */ StructuralEqual(bool compare_ndarray_data = true) : compare_ndarray_data_(compare_ndarray_data) {} // inheritate operator() @@ -115,6 +119,7 @@ class StructuralEqual : public BaseValueEqual { TVM_DLL bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const; private: + /*! \brief Whether or not we compare ndarray data to determine equality. */ bool compare_ndarray_data_; }; diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index c5181468873e..219a51e6ac12 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -74,7 +74,12 @@ class BaseValueHash { */ class StructuralHash : public BaseValueHash { public: + /*! + * \brief Constructor + * \param hash_ndarray_data Whether or not we hash ndarray data. + */ StructuralHash(bool hash_ndarray_data = true) : hash_ndarray_data_(hash_ndarray_data) {} + // inheritate operator() using BaseValueHash::operator(); /*! @@ -85,6 +90,7 @@ class StructuralHash : public BaseValueHash { TVM_DLL size_t operator()(const ObjectRef& key) const; private: + /*! \brief Whether or not we hash ndarray data. */ bool hash_ndarray_data_; }; diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index 8eace7762a43..b86721f3a992 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -194,6 +194,9 @@ def structural_equal(lhs, rhs, map_free_vars=False, compare_ndarray_data=True): Whether free variables (i.e. variables without a definition site) should be mapped as equal to each other. + compare_ndarray_data : bool + Whether or not we compare ndarray data to determine equality. + Return ------ result : bool @@ -243,7 +246,7 @@ def get_first_structural_mismatch(lhs, rhs, map_free_vars=False): return mismatch.lhs_path, mismatch.rhs_path -def assert_structural_equal(lhs, rhs, map_free_vars=False): +def assert_structural_equal(lhs, rhs, map_free_vars=False, compare_ndarray_data=True): """Assert lhs and rhs are structurally equal to each other. Parameters @@ -258,6 +261,9 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False): Whether or not shall we map free vars that does not bound to any definitions as equal to each other. + compare_ndarray_data : bool + Whether or not we compare ndarray data to determine equality. + Raises ------ ValueError : if assertion does not hold. @@ -268,7 +274,7 @@ def assert_structural_equal(lhs, rhs, map_free_vars=False): """ lhs = tvm.runtime.convert(lhs) rhs = tvm.runtime.convert(rhs) - tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars) + tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars, compare_ndarray_data) def structural_hash(node, map_free_vars=False, hash_ndarray_data=True): @@ -301,6 +307,9 @@ def structural_hash(node, map_free_vars=False, hash_ndarray_data=True): by the order of their occurrences. Otherwise, we will hash by their in-memory pointer address. + hash_ndarray_data : bool + Whether or not we hash ndarray data. + Return ------ result : int diff --git a/src/meta_schedule/database/memory_database.cc b/src/meta_schedule/database/memory_database.cc index bbab8944dfc9..043e9a0c09a1 100644 --- a/src/meta_schedule/database/memory_database.cc +++ b/src/meta_schedule/database/memory_database.cc @@ -46,7 +46,7 @@ class MemoryDatabaseNode : public DatabaseNode { Workload CommitWorkload(const IRModule& mod) final { for (const auto& workload : workloads) { - if (StructuralEqual(/*compare_ndarray_data*/false)(workload->mod, mod)) { + if (StructuralEqual(/*compare_ndarray_data*/ false)(workload->mod, mod)) { return workload; } } diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 43d0dbf55975..32a9dab2a108 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -470,7 +470,7 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler { std::unordered_map equal_map_lhs_; // map from rhs to lhs std::unordered_map equal_map_rhs_; - // TODO + // Whether or not compare ndarray raw data bool compare_ndarray_data_; }; diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 5126cb7e4fda..bad7949dc23b 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -277,7 +277,7 @@ class VarCountingSHashHandler : public SHashReducer::Handler { ReflectionVTable* vtable_ = ReflectionVTable::Global(); // map from lhs to rhs std::unordered_map hash_memo_; - // TODO + // Whether or not hash ndarray raw data bool hash_ndarray_data_; }; From a73d79050256d9083cbb756577855ab60ef2f044 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 6 Sep 2022 10:46:46 +0900 Subject: [PATCH 6/8] Do not hard code ignore option in CCachekey --- include/tvm/node/structural_equal.h | 3 ++- include/tvm/node/structural_hash.h | 2 +- src/node/structural_hash.cc | 17 +++++++++-------- src/relay/backend/task_extraction.cc | 2 +- src/relay/backend/te_compiler.cc | 2 +- src/relay/backend/te_compiler_cache.cc | 3 ++- src/relay/backend/te_compiler_cache.h | 12 ++++++++---- 7 files changed, 24 insertions(+), 17 deletions(-) diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index 74ea9654e1e2..7c9081e96baa 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -106,7 +106,8 @@ class StructuralEqual : public BaseValueEqual { * \brief Constructor * \param compare_ndarray_data Whether or not we compare ndarray data to determine equality. */ - StructuralEqual(bool compare_ndarray_data = true) : compare_ndarray_data_(compare_ndarray_data) {} + explicit StructuralEqual(bool compare_ndarray_data = true) + : compare_ndarray_data_(compare_ndarray_data) {} // inheritate operator() using BaseValueEqual::operator(); diff --git a/include/tvm/node/structural_hash.h b/include/tvm/node/structural_hash.h index 219a51e6ac12..bd106f85cab2 100644 --- a/include/tvm/node/structural_hash.h +++ b/include/tvm/node/structural_hash.h @@ -78,7 +78,7 @@ class StructuralHash : public BaseValueHash { * \brief Constructor * \param hash_ndarray_data Whether or not we hash ndarray data. */ - StructuralHash(bool hash_ndarray_data = true) : hash_ndarray_data_(hash_ndarray_data) {} + explicit StructuralHash(bool hash_ndarray_data = true) : hash_ndarray_data_(hash_ndarray_data) {} // inheritate operator() using BaseValueHash::operator(); diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index bad7949dc23b..86e8e96a87a5 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -48,18 +48,19 @@ void ReflectionVTable::SHashReduce(const Object* self, SHashReducer reducer) con fshash_reduce_[tindex](self, reducer); } -void NDArrayHash(const runtime::NDArray::Container* arr, SHashReducer& hash_reduce, +void NDArrayHash(const runtime::NDArray::Container* arr, SHashReducer* hash_reduce, bool hash_data) { ICHECK_EQ(arr->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; ICHECK(runtime::IsContiguous(arr->dl_tensor)) << "Can only hash contiguous tensor"; - hash_reduce(runtime::DataType(arr->dl_tensor.dtype)); - hash_reduce(arr->dl_tensor.ndim); + (*hash_reduce)(runtime::DataType(arr->dl_tensor.dtype)); + (*hash_reduce)(arr->dl_tensor.ndim); for (int i = 0; i < arr->dl_tensor.ndim; ++i) { - hash_reduce(arr->dl_tensor.shape[i]); + (*hash_reduce)(arr->dl_tensor.shape[i]); } if (hash_data) { - hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes( - static_cast(arr->dl_tensor.data), runtime::GetDataSize(arr->dl_tensor))); + (*hash_reduce) + ->SHashReduceHashedValue(runtime::String::HashBytes( + static_cast(arr->dl_tensor.data), runtime::GetDataSize(arr->dl_tensor))); } } @@ -254,7 +255,7 @@ class VarCountingSHashHandler : public SHashReducer::Handler { ICHECK(object.defined()); SHashReducer hash_reduce(this, map_free_vars); if (auto ndarray = object.as()) { - NDArrayHash(ndarray, hash_reduce, hash_ndarray_data_); + NDArrayHash(ndarray, &hash_reduce, hash_ndarray_data_); } else { vtable_->SHashReduce(object.get(), hash_reduce); } @@ -357,7 +358,7 @@ TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait); void NDArrayContainerTrait::SHashReduce(const runtime::NDArray::Container* key, SHashReducer hash_reduce) { - NDArrayHash(key, hash_reduce, /*hash_data*/ true); + NDArrayHash(key, &hash_reduce, /*hash_data*/ true); } TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait) diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index 213841c621de..74577c973eb0 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -49,7 +49,7 @@ Array ExtractTask(IRModule mod, Target target, if (!relay_func->HasNonzeroAttr(attr::kPrimitive)) { return; } - tec::CCacheKey cache_key(relay_func, target); + tec::CCacheKey cache_key(relay_func, target, /*ignore_ndarray_data*/ true); auto it = cache.find(cache_key); if (it != cache.end()) { it->second->weight += 1; diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 8fa8610c0fca..a7efe969e11b 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -918,7 +918,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { } else { // Cases 1 and 2: lower the primitive function for the desired target, possibly using external // codegen. - CCacheKey key(Downcast(primitive_func), target, + CCacheKey key(Downcast(primitive_func), target, /*ignore_ndarray_data*/ false, GetVirtualDevice(GetRef(call_node))); CachedFunc cfunc = compiler_->Lower(key); ICHECK(cfunc.defined()); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 1d7566ebe2bd..e7b3644384e4 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -69,10 +69,11 @@ LoweredOutput::LoweredOutput(tvm::Array outputs, OpImplementation im data_ = std::move(n); } -CCacheKey::CCacheKey(Function source_func, Target target, VirtualDevice vd) { +CCacheKey::CCacheKey(Function source_func, Target target, bool ignore_ndarray_data, VirtualDevice vd) { auto n = make_object(); n->source_func = std::move(source_func); n->target = std::move(target); + n->ignore_ndarray_data = ignore_ndarray_data; n->virtual_device = std::move(vd); data_ = std::move(n); } diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h index af8f9fc87831..83976341010d 100644 --- a/src/relay/backend/te_compiler_cache.h +++ b/src/relay/backend/te_compiler_cache.h @@ -86,10 +86,13 @@ class CCacheKeyNode : public Object { Target target; /*! \brief The virtual device constrains.*/ VirtualDevice virtual_device; + /*! \brief Whether or not we ignore ndarray raw data when comparing and hashing them. */ + bool ignore_ndarray_data; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("source_func", &source_func); v->Visit("target", &target); + v->Visit("ignore_ndarray_data", &ignore_ndarray_data); v->Visit("virtual_device", &virtual_device); } /*! \return The hash value of CCacheKey. */ @@ -121,8 +124,10 @@ class CCacheKey : public ObjectRef { * \brief The constructor * \param source_func The source function. * \param target The target device. + * \param ignore_ndarray_data Whether or not we ignore ndarray raw data when comparing and hashing + * them */ - TVM_DLL CCacheKey(Function source_func, Target target, + TVM_DLL CCacheKey(Function source_func, Target target, bool ignore_ndarray_data = false, VirtualDevice virtual_device = VirtualDevice::FullyUnconstrained()); const CCacheKeyNode* operator->() const { return static_cast(get()); } @@ -238,7 +243,7 @@ CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, inline size_t CCacheKeyNode::Hash() const { if (hash_ != 0) return hash_; // do structral hash, avoid 0. - hash_ = tvm::StructuralHash(/*hash_ndarray_data*/ false)(this->source_func); + hash_ = tvm::StructuralHash(!ignore_ndarray_data)(this->source_func); hash_ = dmlc::HashCombine(hash_, std::hash()(target->str())); if (hash_ == 0) hash_ = 1; return hash_; @@ -248,8 +253,7 @@ inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const { if (Hash() != other->Hash()) return false; return this->target->str() == other->target->str() && this->virtual_device == other->virtual_device && - tvm::StructuralEqual(/*compare_ndarray_data*/ false)(this->source_func, - other->source_func); + tvm::StructuralEqual(!ignore_ndarray_data)(this->source_func, other->source_func); } } // namespace tec From 92121448ad3ab921e0393310e5b1e9b4d3d759f0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 8 Sep 2022 17:00:12 +0900 Subject: [PATCH 7/8] add missing ndarray ignore --- include/tvm/meta_schedule/database.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index fa488a38ce0a..f0cb1bb6af96 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -94,7 +94,7 @@ struct WorkloadHash { /*! \brief The equality check for Workload */ struct WorkloadEqual { bool operator()(const Workload& a, const Workload& b) const { - return a->shash == b->shash && tvm::StructuralEqual()(a->mod, b->mod); + return a->shash == b->shash && tvm::StructuralEqual(false)(a->mod, b->mod); } }; From 8c5ef429af313feedd02fdae457404d1eced6a92 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 8 Sep 2022 19:18:07 +0900 Subject: [PATCH 8/8] more fix to workload lookup. The result is now correct --- src/relay/backend/te_compiler_cache.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index e7b3644384e4..496e26102fd0 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -366,8 +366,9 @@ class ScheduleBuilder : public ExprVisitor { constants.push_back(const_node->data); } if (Optional f = tir_converter(te_args, constants)) { + IRModule query_mod = backend::PrimFuncToIRModule(f.value()); if (Optional opt_record = database_.value()->QueryTuningRecord( - /*mod=*/backend::PrimFuncToIRModule(f.value()), + /*mod=*/query_mod, /*target=*/target_, /*workload_name=*/prim_fn_var->name_hint)) { static InstructionKind kind_transform_layout = InstructionKind::Get("TransformLayout"); @@ -378,7 +379,7 @@ class ScheduleBuilder : public ExprVisitor { MetaScheduleLayoutRewriter::LayoutQueuePush(Downcast(inst->attrs[2])); } } - Schedule sch = Schedule::Traced(record->workload->mod, /*seed=*/-1, /*debug_mask=*/0, + Schedule sch = Schedule::Traced(query_mod, /*seed=*/-1, /*debug_mask=*/0, tir::ScheduleErrorRenderLevel::kDetail); record->trace->ApplyToSchedule(sch, /*remove_postproc=*/false); IRModule mod = sch->mod();