4 changes: 4 additions & 0 deletions include/tvm/meta_schedule/database.h
@@ -181,6 +181,8 @@ class DatabaseNode : public runtime::Object {
* \param mod_eq_name A string to specify the module equality testing and hashing method.
* It must be one of the following:
* - "structural": Use StructuralEqual/Hash
* - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during
* equality testing and hashing.
*/
explicit DatabaseNode(String mod_eq_name = "structural");

@@ -270,6 +272,8 @@ class PyDatabaseNode : public DatabaseNode {
* \param mod_eq_name A string to specify the module equality testing and hashing method.
* It must be one of the following:
* - "structural": Use StructuralEqual/Hash
* - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during
* equality testing and hashing.
*/
explicit PyDatabaseNode(String mod_eq_name = "structural");

2 changes: 2 additions & 0 deletions python/tvm/meta_schedule/database/json_database.py
@@ -38,6 +38,8 @@ class JSONDatabase(Database):
A string to specify the module equality testing and hashing method.
It must be one of the following:
- "structural": Use StructuralEqual/Hash
- "ignore-ndarray": Same as "structural", but ignore ndarray raw data during
equality testing and hashing.
"""

path_workload: str
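For reference, a minimal sketch of selecting the new option from Python, assuming the constructor arguments documented above; the file paths are hypothetical:

    from tvm import meta_schedule as ms

    # "ignore-ndarray" makes the database treat two IRModules that differ
    # only in constant (ndarray) raw data as the same workload.
    db = ms.database.JSONDatabase(
        path_workload="workloads.json",  # hypothetical path
        path_tuning_record="records.json",  # hypothetical path
        module_equality="ignore-ndarray",
    )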
2 changes: 2 additions & 0 deletions python/tvm/meta_schedule/database/memory_database.py
@@ -31,6 +31,8 @@ class MemoryDatabase(Database):
A string to specify the module equality testing and hashing method.
It must be one of the following:
- "structural": Use StructuralEqual/Hash
- "ignore-ndarray": Same as "structural", but ignore ndarray raw data during
equality testing and hashing.
"""

def __init__(
2 changes: 2 additions & 0 deletions python/tvm/meta_schedule/database/schedule_fn_database.py
@@ -37,6 +37,8 @@ class ScheduleFnDatabase(Database):
A string to specify the module equality testing and hashing method.
It must be one of the following:
- "structural": Use StructuralEqual/Hash
- "ignore-ndarray": Same as "structural", but ignore ndarray raw data during
equality testing and hashing.
"""

def __init__(
4 changes: 4 additions & 0 deletions python/tvm/meta_schedule/relay_integration.py
@@ -141,6 +141,8 @@ def extract_tasks(
A string to specify the module equality testing and hashing method.
It must be one of the following:
- "structural": Use StructuralEqual/Hash
- "ignore-ndarray": Same as "structural", but ignore ndarray raw data during
equality testing and hashing.

Returns
-------
@@ -284,6 +286,8 @@ def tune_relay(
A string to specify the module equality testing and hashing method.
It must be one of the following:
- "structural": Use StructuralEqual/Hash
- "ignore-ndarray": Same as "structural", but ignore ndarray raw data during
equality testing and hashing.

Returns
-------
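A hedged sketch of how the flag threads through both Relay entry points; mod, params, and target stand in for a real model, and the work directory and trial budget are placeholder values:

    from tvm import meta_schedule as ms

    # Tasks that differ only in their linked constants collapse into one.
    tasks = ms.relay_integration.extract_tasks(
        mod, target, params, module_equality="ignore-ndarray"
    )
    # Pass the same string when tuning so that workload lookup hashes
    # modules consistently with task extraction.
    database = ms.relay_integration.tune_relay(
        mod=mod,
        params=params,
        target=target,
        work_dir="./tuning_logs",  # hypothetical directory
        max_trials_global=64,
        module_equality="ignore-ndarray",
    )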
2 changes: 2 additions & 0 deletions python/tvm/meta_schedule/tune.py
@@ -74,6 +74,8 @@ def tune_tasks(
A string to specify the module equality testing and hashing method.
It must be one of the following:
- "structural": Use StructuralEqual/Hash
- "ignore-ndarray": Same as "structural", but ignore ndarray raw data during
equality testing and hashing.

Returns
-------
42 changes: 42 additions & 0 deletions src/meta_schedule/module_equality.cc
@@ -24,6 +24,8 @@

#include <memory>

#include "../node/ndarray_hash_equal.h"

namespace tvm {
namespace meta_schedule {

@@ -33,9 +35,49 @@ class ModuleEqualityStructural : public ModuleEquality {
bool Equal(IRModule lhs, IRModule rhs) const { return tvm::StructuralEqual()(lhs, rhs); }
};

class SEqualHandlerIgnoreNDArray : public SEqualHandlerDefault {
public:
SEqualHandlerIgnoreNDArray() : SEqualHandlerDefault(false, nullptr) {}

protected:
bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
                          const Optional<ObjectPathPair>& current_paths) override {
if (auto lhs_ptr = lhs.as<runtime::NDArray::Container>(),
rhs_ptr = rhs.as<runtime::NDArray::Container>();
lhs_ptr && rhs_ptr) {
SEqualReducer reducer(this, nullptr, map_free_vars);
return NDArrayEqual(lhs_ptr, rhs_ptr, reducer, false);
}
return SEqualHandlerDefault::DispatchSEqualReduce(lhs, rhs, map_free_vars, current_paths);
}
};

class SHashHandlerIgnoreNDArray : public SHashHandlerDefault {
protected:
void DispatchSHash(const ObjectRef& object, bool map_free_vars) override {
ICHECK(object.defined());
if (auto ndarray = object.as<runtime::NDArray::Container>()) {
SHashReducer hash_reduce(this, map_free_vars);
NDArrayHash(ndarray, &hash_reduce, false);
} else {
SHashHandlerDefault::DispatchSHash(object, map_free_vars);
}
}
};

class ModuleEqualityIgnoreNDArray : public ModuleEquality {
public:
size_t Hash(IRModule mod) const { return SHashHandlerIgnoreNDArray().Hash(mod, false); }
bool Equal(IRModule lhs, IRModule rhs) const {
return SEqualHandlerIgnoreNDArray().Equal(lhs, rhs, false);
}
};

std::unique_ptr<ModuleEquality> ModuleEquality::Create(const std::string& mod_eq_name) {
if (mod_eq_name == "structural") {
return std::make_unique<ModuleEqualityStructural>();
} else if (mod_eq_name == "ignore-ndarray") {
return std::make_unique<ModuleEqualityIgnoreNDArray>();
}
LOG(FATAL) << "Unknown module equality " << mod_eq_name;
return nullptr;
2 changes: 2 additions & 0 deletions src/meta_schedule/module_equality.h
@@ -40,6 +40,8 @@ class ModuleEquality {
* \param mod_eq_name A string to specify the module equality testing and hashing method.
* It must be one of the following:
* - "structural": Use StructuralEqual/Hash
* - "ignore-ndarray": Same as "structural", but ignore ndarray raw data during
* equality testing and hashing.
* \return An owning pointer to the created instance
*/
static std::unique_ptr<ModuleEquality> Create(const std::string& mod_eq_name);
52 changes: 52 additions & 0 deletions src/node/ndarray_hash_equal.h
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_NODE_NDARRAY_HASH_EQUAL_H_
#define TVM_NODE_NDARRAY_HASH_EQUAL_H_

#include <tvm/runtime/ndarray.h>

namespace tvm {

class SEqualReducer;
class SHashReducer;

/*!
* \brief Test two NDArrays for equality.
* \param lhs The left operand.
* \param rhs The right operand.
* \param equal A Reducer class to reduce the structural equality result of two objects.
* See tvm/node/structural_equal.h.
* \param compare_data Whether or not to consider ndarray raw data in the equality testing.
* \return The equality testing result.
*/
bool NDArrayEqual(const runtime::NDArray::Container* lhs, const runtime::NDArray::Container* rhs,
SEqualReducer equal, bool compare_data);

/*!
* \brief Hash NDArray.
* \param arr The NDArray to compute the hash for.
* \param hash_reduce A Reducer class to reduce the structural hash value.
* See tvm/node/structural_hash.h.
* \param hash_data Whether or not to hash ndarray raw data.
*/
void NDArrayHash(const runtime::NDArray::Container* arr, SHashReducer* hash_reduce, bool hash_data);

} // namespace tvm

#endif // TVM_NODE_NDARRAY_HASH_EQUAL_H_
35 changes: 35 additions & 0 deletions src/node/structural_equal.cc
@@ -29,6 +29,8 @@

#include <unordered_map>

#include "ndarray_hash_equal.h"

namespace tvm {

TVM_REGISTER_OBJECT_TYPE(ObjectPathPairNode);
@@ -476,4 +478,37 @@ bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
return SEqualHandlerDefault(false, nullptr).Equal(lhs, rhs, false);
}

bool NDArrayEqual(const runtime::NDArray::Container* lhs, const runtime::NDArray::Container* rhs,
SEqualReducer equal, bool compare_data) {
if (lhs == rhs) return true;

auto ldt = lhs->dl_tensor.dtype;
auto rdt = rhs->dl_tensor.dtype;
ICHECK_EQ(lhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor";
ICHECK_EQ(rhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor";
ICHECK(runtime::IsContiguous(lhs->dl_tensor)) << "Can only compare contiguous tensor";
ICHECK(runtime::IsContiguous(rhs->dl_tensor)) << "Can only compare contiguous tensor";

if (lhs->dl_tensor.ndim != rhs->dl_tensor.ndim) return false;
for (int i = 0; i < lhs->dl_tensor.ndim; ++i) {
if (!equal(lhs->dl_tensor.shape[i], rhs->dl_tensor.shape[i])) return false;
}
if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) {
size_t data_size = runtime::GetDataSize(lhs->dl_tensor);
if (compare_data) {
return std::memcmp(lhs->dl_tensor.data, rhs->dl_tensor.data, data_size) == 0;
} else {
return true;
}
} else {
return false;
}
}

bool NDArrayContainerTrait::SEqualReduce(const runtime::NDArray::Container* lhs,
const runtime::NDArray::Container* rhs,
SEqualReducer equal) {
return NDArrayEqual(lhs, rhs, equal, true);
}

} // namespace tvm
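The comparison above reduces to a small predicate; here is an illustrative Python mirror over contiguous NumPy arrays (not a TVM API): shape and dtype always participate, while the raw bytes are consulted only when compare_data is set.

    import numpy as np

    def ndarray_equal(lhs: np.ndarray, rhs: np.ndarray, compare_data: bool) -> bool:
        # Mirrors NDArrayEqual: metadata first, data only on request.
        if lhs is rhs:
            return True
        if lhs.shape != rhs.shape or lhs.dtype != rhs.dtype:
            return False
        return not compare_data or lhs.tobytes() == rhs.tobytes()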
51 changes: 18 additions & 33 deletions src/node/structural_hash.cc
@@ -35,6 +35,7 @@
#include "../support/base64.h"
#include "../support/str_escape.h"
#include "../support/utils.h"
#include "ndarray_hash_equal.h"

namespace tvm {

@@ -359,41 +360,25 @@ struct ADTObjTrait {

TVM_REGISTER_REFLECTION_VTABLE(runtime::ADTObj, ADTObjTrait);

void NDArrayContainerTrait::SHashReduce(const runtime::NDArray::Container* key,
SHashReducer hash_reduce) {
ICHECK_EQ(key->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor";
ICHECK(runtime::IsContiguous(key->dl_tensor)) << "Can only hash contiguous tensor";
hash_reduce(runtime::DataType(key->dl_tensor.dtype));
hash_reduce(key->dl_tensor.ndim);
for (int i = 0; i < key->dl_tensor.ndim; ++i) {
hash_reduce(key->dl_tensor.shape[i]);
}
hash_reduce->SHashReduceHashedValue(runtime::String::HashBytes(
static_cast<const char*>(key->dl_tensor.data), runtime::GetDataSize(key->dl_tensor)));
void NDArrayHash(const runtime::NDArray::Container* arr, SHashReducer* hash_reduce,
bool hash_data) {
ICHECK_EQ(arr->dl_tensor.device.device_type, kDLCPU) << "can only hash CPU tensor";
ICHECK(runtime::IsContiguous(arr->dl_tensor)) << "Can only hash contiguous tensor";
(*hash_reduce)(runtime::DataType(arr->dl_tensor.dtype));
(*hash_reduce)(arr->dl_tensor.ndim);
for (int i = 0; i < arr->dl_tensor.ndim; ++i) {
(*hash_reduce)(arr->dl_tensor.shape[i]);
}
if (hash_data) {
(*hash_reduce)
->SHashReduceHashedValue(runtime::String::HashBytes(
static_cast<const char*>(arr->dl_tensor.data), runtime::GetDataSize(arr->dl_tensor)));
}
}

bool NDArrayContainerTrait::SEqualReduce(const runtime::NDArray::Container* lhs,
const runtime::NDArray::Container* rhs,
SEqualReducer equal) {
if (lhs == rhs) return true;

auto ldt = lhs->dl_tensor.dtype;
auto rdt = rhs->dl_tensor.dtype;
ICHECK_EQ(lhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor";
ICHECK_EQ(rhs->dl_tensor.device.device_type, kDLCPU) << "can only compare CPU tensor";
ICHECK(runtime::IsContiguous(lhs->dl_tensor)) << "Can only compare contiguous tensor";
ICHECK(runtime::IsContiguous(rhs->dl_tensor)) << "Can only compare contiguous tensor";

if (lhs->dl_tensor.ndim != rhs->dl_tensor.ndim) return false;
for (int i = 0; i < lhs->dl_tensor.ndim; ++i) {
if (!equal(lhs->dl_tensor.shape[i], rhs->dl_tensor.shape[i])) return false;
}
if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) {
size_t data_size = runtime::GetDataSize(lhs->dl_tensor);
return std::memcmp(lhs->dl_tensor.data, rhs->dl_tensor.data, data_size) == 0;
} else {
return false;
}
void NDArrayContainerTrait::SHashReduce(const runtime::NDArray::Container* key,
SHashReducer hash_reduce) {
NDArrayHash(key, &hash_reduce, /*hash_data=*/true);
}

TVM_REGISTER_REFLECTION_VTABLE(runtime::NDArray::Container, NDArrayContainerTrait)
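Likewise, the refactored NDArrayHash always folds dtype, rank, and shape into the hash, and folds the raw bytes only when hash_data is true; an illustrative Python mirror (not a TVM API):

    import numpy as np

    def ndarray_hash(arr: np.ndarray, hash_data: bool) -> int:
        # Mirrors NDArrayHash: dtype, ndim, and shape always contribute;
        # the payload bytes contribute only when hash_data is True.
        items = [str(arr.dtype), arr.ndim, *arr.shape]
        if hash_data:
            items.append(arr.tobytes())
        return hash(tuple(items))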
59 changes: 59 additions & 0 deletions tests/python/unittest/test_meta_schedule_relay_integration.py
@@ -568,5 +568,64 @@ def test_rewrite_layout_link_params():
np.testing.assert_allclose(ref, out, rtol=1e-4, atol=1e-4)


def test_module_equality_ignore_ndarray():
target = "llvm --num-cores=4"

data_shape = (128, 128)
weight_shape1 = (128, 128)
weight_shape2 = (128, 128)

data = relay.var("data", shape=data_shape, dtype="float32")
weight1 = relay.var("weight1", shape=weight_shape1, dtype="float32")
weight2 = relay.var("weight2", shape=weight_shape2, dtype="float32")
dense1 = relay.nn.dense(data, weight1)
dense2 = relay.nn.dense(dense1, weight2)
mod = tvm.IRModule.from_expr(dense2)

weight1_np = np.random.randn(*weight_shape1).astype("float32")
weight2_np = np.random.randn(*weight_shape2).astype("float32")

params = {"weight1": weight1_np, "weight2": weight2_np}

executor = relay.backend.Executor("graph", {"link-params": True})
mod = mod.with_attr("executor", executor)

# Without using ignore-ndarray for module equality, we get duplicated tasks
assert len(ms.relay_integration.extract_tasks(mod, target, params)) == 2

module_equality = "ignore-ndarray"
extracted_tasks = ms.relay_integration.extract_tasks(
mod, target, params, module_equality=module_equality
)

assert len(extracted_tasks) == 1

with tempfile.TemporaryDirectory() as work_dir:
tasks, task_weights = ms.relay_integration.extracted_tasks_to_tune_contexts(
extracted_tasks, work_dir, strategy="replay-trace"
)
database = ms.tune.tune_tasks(
tasks=tasks,
task_weights=task_weights,
work_dir=work_dir,
max_trials_global=4,
module_equality=module_equality,
)
lib = ms.relay_integration.compile_relay(database, mod, target, params)

dev = tvm.device(target, 0)
runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))

data_np = np.random.randn(*data_shape).astype("float32")

runtime.set_input("data", data_np)
runtime.run()

out = runtime.get_output(0).numpy()

ref = np.dot(np.dot(data_np, weight1_np.transpose()), weight2_np.transpose())
np.testing.assert_allclose(ref, out, rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
tvm.testing.main()