diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index bcdffe9ff33b..9eead8d5ec31 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -181,6 +181,8 @@ class DatabaseNode : public runtime::Object { * \param mod_eq_name A string to specify the module equality testing and hashing method. * It must be one of the followings: * - "structural": Use StructuralEqual/Hash + * - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during + * equality testing and hashing. */ explicit DatabaseNode(String mod_eq_name = "structural"); @@ -270,6 +272,8 @@ class PyDatabaseNode : public DatabaseNode { * \param mod_eq_name A string to specify the module equality testing and hashing method. * It must be one of the followings: * - "structural": Use StructuralEqual/Hash + * - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during + * equality testing and hashing. */ explicit PyDatabaseNode(String mod_eq_name = "structural"); diff --git a/python/tvm/meta_schedule/database/json_database.py b/python/tvm/meta_schedule/database/json_database.py index aedc83ad89b3..f81d8913c18a 100644 --- a/python/tvm/meta_schedule/database/json_database.py +++ b/python/tvm/meta_schedule/database/json_database.py @@ -38,6 +38,8 @@ class JSONDatabase(Database): A string to specify the module equality testing and hashing method. It must be one of the followings: - "structural": Use StructuralEqual/Hash + - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during + equality testing and hashing. 
""" path_workload: str diff --git a/python/tvm/meta_schedule/database/memory_database.py b/python/tvm/meta_schedule/database/memory_database.py index e07f325d9d3d..96b9bb5a0112 100644 --- a/python/tvm/meta_schedule/database/memory_database.py +++ b/python/tvm/meta_schedule/database/memory_database.py @@ -31,6 +31,8 @@ class MemoryDatabase(Database): A string to specify the module equality testing and hashing method. It must be one of the followings: - "structural": Use StructuralEqual/Hash + - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during + equality testing and hashing. """ def __init__( diff --git a/python/tvm/meta_schedule/database/schedule_fn_database.py b/python/tvm/meta_schedule/database/schedule_fn_database.py index 273b84185287..7a0b433996c5 100644 --- a/python/tvm/meta_schedule/database/schedule_fn_database.py +++ b/python/tvm/meta_schedule/database/schedule_fn_database.py @@ -37,6 +37,8 @@ class ScheduleFnDatabase(Database): A string to specify the module equality testing and hashing method. It must be one of the followings: - "structural": Use StructuralEqual/Hash + - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during + equality testing and hashing. """ def __init__( diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index b9c34e509ab4..089f6e412e20 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -141,6 +141,8 @@ def extract_tasks( A string to specify the module equality testing and hashing method. It must be one of the followings: - "structural": Use StructuralEqual/Hash + - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during + equality testing and hashing. Returns ------- @@ -284,6 +286,8 @@ def tune_relay( A string to specify the module equality testing and hashing method. 
It must be one of the followings: - "structural": Use StructuralEqual/Hash + - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during + equality testing and hashing. Returns ------- diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index 66cb60c32902..07021eac3998 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -74,6 +74,8 @@ def tune_tasks( A string to specify the module equality testing and hashing method. It must be one of the followings: - "structural": Use StructuralEqual/Hash + - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during + equality testing and hashing. Returns ------- diff --git a/src/meta_schedule/module_equality.cc b/src/meta_schedule/module_equality.cc index 084ae74bb09c..caa7da170bd6 100644 --- a/src/meta_schedule/module_equality.cc +++ b/src/meta_schedule/module_equality.cc @@ -24,6 +24,8 @@ #include +#include "../node/ndarray_hash_equal.h" + namespace tvm { namespace meta_schedule { @@ -33,9 +35,52 @@ class ModuleEqualityStructural : public ModuleEquality { bool Equal(IRModule lhs, IRModule rhs) const { return tvm::StructuralEqual()(lhs, rhs); } }; +/*! \brief Structural-equality handler that compares NDArrays with compare_data = false. */ +class SEqualHandlerIgnoreNDArray : public SEqualHandlerDefault { + public: + SEqualHandlerIgnoreNDArray() : SEqualHandlerDefault(false, nullptr) {} + + protected: + bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, + const Optional& current_paths) override { + if (auto lhs_ptr = lhs.as(), + rhs_ptr = rhs.as(); + lhs_ptr && rhs_ptr) { + SEqualReducer reducer(this, nullptr, map_free_vars); + return NDArrayEqual(lhs_ptr, rhs_ptr, reducer, false); + } + return SEqualHandlerDefault::DispatchSEqualReduce(lhs, rhs, map_free_vars, current_paths); + } +}; + +/*! \brief Structural-hash handler that hashes NDArrays with hash_data = false. */ +class SHashHandlerIgnoreNDArray : public SHashHandlerDefault { + protected: + void DispatchSHash(const ObjectRef& object, bool map_free_vars) override { + ICHECK(object.defined()); + if (auto ndarray = 
object.as()) { + SHashReducer hash_reduce(this, map_free_vars); + NDArrayHash(ndarray, &hash_reduce, false); + } else { + SHashHandlerDefault::DispatchSHash(object, map_free_vars); + } + } +}; + +/*! \brief Module equality selected by "ignore-ndarray"; uses the two IgnoreNDArray handlers. */ +class ModuleEqualityIgnoreNDArray : public ModuleEquality { + public: + size_t Hash(IRModule mod) const { return SHashHandlerIgnoreNDArray().Hash(mod, false); } + bool Equal(IRModule lhs, IRModule rhs) const { + return SEqualHandlerIgnoreNDArray().Equal(lhs, rhs, false); + } +}; + std::unique_ptr ModuleEquality::Create(const std::string& mod_eq_name) { if (mod_eq_name == "structural") { return std::make_unique(); + } else if (mod_eq_name == "ignore-ndarray") { + return std::make_unique(); } LOG(FATAL) << "Unknown module equality " << mod_eq_name; return nullptr; diff --git a/src/meta_schedule/module_equality.h b/src/meta_schedule/module_equality.h index 3e6fb55d8a9b..8c99b563551b 100644 --- a/src/meta_schedule/module_equality.h +++ b/src/meta_schedule/module_equality.h @@ -40,6 +40,8 @@ class ModuleEquality { * \param mod_eq_name A string to specify the module equality testing and hashing method. * It must be one of the followings: * - "structural": Use StructuralEqual/Hash + * - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during + * equality testing and hashing. * \return An owning pointer to the created instance */ static std::unique_ptr Create(const std::string& mod_eq_name); diff --git a/src/node/ndarray_hash_equal.h b/src/node/ndarray_hash_equal.h new file mode 100644 index 000000000000..d674018fbdd2 --- /dev/null +++ b/src/node/ndarray_hash_equal.h @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_NODE_NDARRAY_HASH_EQUAL_H_ +#define TVM_NODE_NDARRAY_HASH_EQUAL_H_ + +#include + +namespace tvm { + +class SEqualReducer; +class SHashReducer; + +/*! + * \brief Test two NDArrays for equality. + * \param lhs The left operand. + * \param rhs The right operand. + * \param equal A Reducer class to reduce the structural equality result of two objects. + * See tvm/node/structural_equal.h. + * \param compare_data Whether or not to consider ndarray raw data in the equality testing. + * \return The equality testing result. + */ +bool NDArrayEqual(const runtime::NDArray::Container* lhs, const runtime::NDArray::Container* rhs, + SEqualReducer equal, bool compare_data); + +/*! + * \brief Hash NDArray. + * \param arr The NDArray to compute the hash for. + * \param hash_reduce A Reducer class to reduce the structural hash value. + * See tvm/node/structural_hash.h. + * \param hash_data Whether or not to hash ndarray raw data. 
+ */ +void NDArrayHash(const runtime::NDArray::Container* arr, SHashReducer* hash_reduce, bool hash_data); + +} // namespace tvm + +#endif // TVM_NODE_NDARRAY_HASH_EQUAL_H_ diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 0a9a0ec0bbb7..0290b7afe3fd 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -29,6 +29,8 @@ #include +#include "ndarray_hash_equal.h" + namespace tvm { TVM_REGISTER_OBJECT_TYPE(ObjectPathPairNode); @@ -476,4 +478,32 @@ bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) con return SEqualHandlerDefault(false, nullptr).Equal(lhs, rhs, false); } +bool NDArrayEqual(const runtime::NDArray::Container* lhs, const runtime::NDArray::Container* rhs, + SEqualReducer equal, bool compare_data) { + if (lhs == rhs) return true; + + auto ldt = lhs->dl_tensor.dtype; + auto rdt = rhs->dl_tensor.dtype; + ICHECK_EQ(lhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; + ICHECK_EQ(rhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; + ICHECK(runtime::IsContiguous(lhs->dl_tensor)) << "Can only compare contiguous tensor"; + ICHECK(runtime::IsContiguous(rhs->dl_tensor)) << "Can only compare contiguous tensor"; + + if (lhs->dl_tensor.ndim != rhs->dl_tensor.ndim) return false; + for (int i = 0; i < lhs->dl_tensor.ndim; ++i) { + if (!equal(lhs->dl_tensor.shape[i], rhs->dl_tensor.shape[i])) return false; + } + if (ldt.code != rdt.code || ldt.lanes != rdt.lanes || ldt.bits != rdt.bits) return false; + // Shapes and dtypes match; the raw bytes are compared only when the caller asks for it. + if (!compare_data) return true; + size_t data_size = runtime::GetDataSize(lhs->dl_tensor); + return std::memcmp(lhs->dl_tensor.data, rhs->dl_tensor.data, data_size) == 0; +} + +bool NDArrayContainerTrait::SEqualReduce(const runtime::NDArray::Container* lhs, + const runtime::NDArray::Container* rhs, + SEqualReducer equal) { + return NDArrayEqual(lhs, rhs, equal, true); +} + } // namespace tvm diff --git
a/src/node/structural_hash.cc b/src/node/structural_hash.cc index a355e44028b6..1d1185cddc3d 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -35,6 +35,7 @@ #include "../support/base64.h" #include "../support/str_escape.h" #include "../support/utils.h" +#include "ndarray_hash_equal.h" namespace tvm { @@ -359,41 +360,26 @@ struct ADTObjTrait { TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait); -void NDArrayContainerTrait::SHashReduce(const runtime::NDArray::Container* key, - SHashReducer hash_reduce) { - ICHECK_EQ(key->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; - ICHECK(runtime::IsContiguous(key->dl_tensor)) << "Can only hash contiguous tensor"; - hash_reduce(runtime::DataType(key->dl_tensor.dtype)); - hash_reduce(key->dl_tensor.ndim); - for (int i = 0; i < key->dl_tensor.ndim; ++i) { - hash_reduce(key->dl_tensor.shape[i]); - } - hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes( - static_cast(key->dl_tensor.data), runtime::GetDataSize(key->dl_tensor))); +void NDArrayHash(const runtime::NDArray::Container* arr, SHashReducer* hash_reduce, + bool hash_data) { + ICHECK_EQ(arr->dl_tensor.device.device_type, kDLCPU) << "can only hash CPU tensor"; + ICHECK(runtime::IsContiguous(arr->dl_tensor)) << "Can only hash contiguous tensor"; + (*hash_reduce)(runtime::DataType(arr->dl_tensor.dtype)); + (*hash_reduce)(arr->dl_tensor.ndim); + for (int i = 0; i < arr->dl_tensor.ndim; ++i) { + (*hash_reduce)(arr->dl_tensor.shape[i]); + } + // dtype, ndim and shape are always hashed; the raw bytes are folded in only on request. + if (hash_data) { + (*hash_reduce) + ->SHashReduceHashedValue(runtime::String::HashBytes( + static_cast(arr->dl_tensor.data), runtime::GetDataSize(arr->dl_tensor))); + } } -bool NDArrayContainerTrait::SEqualReduce(const runtime::NDArray::Container* lhs, - const runtime::NDArray::Container* rhs, - SEqualReducer equal) { - if (lhs == rhs) return true; - - auto ldt = lhs->dl_tensor.dtype; - auto rdt = rhs->dl_tensor.dtype; - ICHECK_EQ(lhs->dl_tensor.device.device_type, 
kDLCPU) << "can only compare CPU tensor"; - ICHECK_EQ(rhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor"; - ICHECK(runtime::IsContiguous(lhs->dl_tensor)) << "Can only compare contiguous tensor"; - ICHECK(runtime::IsContiguous(rhs->dl_tensor)) << "Can only compare contiguous tensor"; - - if (lhs->dl_tensor.ndim != rhs->dl_tensor.ndim) return false; - for (int i = 0; i < lhs->dl_tensor.ndim; ++i) { - if (!equal(lhs->dl_tensor.shape[i], rhs->dl_tensor.shape[i])) return false; - } - if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) { - size_t data_size = runtime::GetDataSize(lhs->dl_tensor); - return std::memcmp(lhs->dl_tensor.data, rhs->dl_tensor.data, data_size) == 0; - } else { - return false; - } +void NDArrayContainerTrait::SHashReduce(const runtime::NDArray::Container* key, + SHashReducer hash_reduce) { + NDArrayHash(key, &hash_reduce, /*hash_data=*/true); } TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait) diff --git a/tests/python/unittest/test_meta_schedule_relay_integration.py b/tests/python/unittest/test_meta_schedule_relay_integration.py index d5c81bcc56ba..e9908cbfde14 100644 --- a/tests/python/unittest/test_meta_schedule_relay_integration.py +++ b/tests/python/unittest/test_meta_schedule_relay_integration.py @@ -568,5 +568,64 @@ def test_rewrite_layout_link_params(): np.testing.assert_allclose(ref, out, rtol=1e-4, atol=1e-4) +def test_module_equality_ignore_ndarray(): + target = "llvm --num-cores=4" + + data_shape = (128, 128) + weight_shape1 = (128, 128) + weight_shape2 = (128, 128) + + data = relay.var("data", shape=data_shape, dtype="float32") + weight1 = relay.var("weight1", shape=weight_shape1, dtype="float32") + weight2 = relay.var("weight2", shape=weight_shape2, dtype="float32") + dense1 = relay.nn.dense(data, weight1) + dense2 = relay.nn.dense(dense1, weight2) + mod = tvm.IRModule.from_expr(dense2) + + weight1_np = 
np.random.randn(*weight_shape1).astype("float32") + weight2_np = np.random.randn(*weight_shape2).astype("float32") + + params = {"weight1": weight1_np, "weight2": weight2_np} + + executor = relay.backend.Executor("graph", {"link-params": True}) + mod = mod.with_attr("executor", executor) + + # Without using ignore-ndarray for module equality, we get duplicated tasks + assert len(ms.relay_integration.extract_tasks(mod, target, params)) == 2 + + module_equality = "ignore-ndarray" + extracted_tasks = ms.relay_integration.extract_tasks( + mod, target, params, module_equality=module_equality + ) + + assert len(extracted_tasks) == 1  # the duplicate collapses once ndarray data is ignored + + with tempfile.TemporaryDirectory() as work_dir: + tasks, task_weights = ms.relay_integration.extracted_tasks_to_tune_contexts( + extracted_tasks, work_dir, strategy="replay-trace" + ) + database = ms.tune.tune_tasks( + tasks=tasks, + task_weights=task_weights, + work_dir=work_dir, + max_trials_global=4, + module_equality=module_equality, + ) + lib = ms.relay_integration.compile_relay(database, mod, target, params) + + dev = tvm.device(target, 0) + runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + + data_np = np.random.randn(*data_shape).astype("float32") + + runtime.set_input("data", data_np) + runtime.run() + + out = runtime.get_output(0).numpy() + + ref = np.dot(np.dot(data_np, weight1_np.transpose()), weight2_np.transpose()) + np.testing.assert_allclose(ref, out, rtol=1e-4, atol=1e-4) + + if __name__ == "__main__": tvm.testing.main()