From 0cf641af65f7ba675a0f81a691651c488180da2b Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 24 Oct 2018 15:52:00 -0700 Subject: [PATCH 01/15] Add first pass at Relay hashing --- include/tvm/relay/pass.h | 11 ++ src/relay/ir/hash.cc | 332 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 343 insertions(+) create mode 100644 src/relay/ir/hash.cc diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 1b3462659e18..4d488bb41250 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -136,6 +136,17 @@ tvm::Array FreeTypeVars(const Expr& expr); */ Expr DeadCodeElimination(const Expr& e); +/*! \brief Hash a Relay type. + * + */ +size_t HashType(const Expr &); + +/*! \brief Hash a Relay expression. + * + */ +size_t HashExpr(const Expr &); + + } // namespace relay } // namespace tvm #endif // TVM_RELAY_PASS_H_ diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc new file mode 100644 index 000000000000..856645a9a98f --- /dev/null +++ b/src/relay/ir/hash.cc @@ -0,0 +1,332 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file src/tvm/relay/ir/hash.cc + * \brief Hash functions for Relay types and expressions. + */ +#include +#include +#include +#include +#include "type_functor.h" +#include "../../lang/attr_functor.h" + +namespace tvm { +namespace relay { + +// Alpha equal handler for relay. +class RelayHashHandler: + public AttrsHashHandler, + public TypeFunctor, + public ExprFunctor { + public: + explicit RelayHashHandler(bool map_free_var) + : map_free_var_(map_free_var) {} + + /*! + * Check equality of two nodes. + * \param lhs The left hand operand. + * \param rhs The right hand operand. + * \return the compare result. + */ + size_t Hash(const NodeRef& ref) { + if (!ref.defined()) return ref.hash(); + if (ref->derived_from()) { + return TypeHash(Downcast(ref)); + } + if (ref->derived_from()) { + return ExprHash(Downcast(ref)); + } + return AttrHash(ref); + } + + /*! + * Compute hash of the attributes. 
+ * \param ref The attributes. + * \return the hash value + */ + size_t AttrHash(const NodeRef& ref) { + return AttrsHashHandler::Hash(ref); + } + /*! + * Compute hash of a Relay type. + * \param ref The type to hash. + * \param rhs The right hand operand. + * \return the hash value. + */ + size_t TypeHash(const Type& type) { + auto found = hash_map_.find(type); + if (found != hash_map_.end()) { + return found->second; + } else { + auto hash = this->VisitType(type); + hash_map_.insert({type, hash}); + return hash; + } + } + /*! + * Compute the hash of an expression. + * + * \note We run graph structural equality checking when comparing two Exprs. + * This means that AlphaEqualHandler can only be used once for each pair. + * The equality checker checks data-flow equvalence of the Expr DAG. + * This function also runs faster as it memomizes equal_map. + * + * \param expr The expression to hash. + * \return the hash value. + */ + size_t ExprHash(const Expr& expr) { + if (!expr.defined()) return expr.hash(); + auto found = hash_map_.find(expr); + if (found != hash_map_.end()) { + return found->second; + } else { + auto hash = this->VisitExpr(expr); + hash_map_.insert({expr, hash}); + return hash; + } + } + + protected: + /*! + * \brief Check if data type equals each other. + * \param lhs The left hand operand. + * \param rhs The right hand operand. + * \return the compare result. + */ + size_t DataTypeHash(const DataType& dtype) { + return std::hash()( + static_cast(dtype.code()) | + (static_cast(dtype.bits()) << 8) | + (static_cast(dtype.lanes()) << 16)); + } + + /*! + * \brief Check Equality of leaf node of the graph. + * if map_free_var_ is set to true, try to map via equal node. + * \param lhs The left hand operand. + * \param rhs The right hand operand. + * \return the compare result. 
+ */ + size_t LeafNodeEqual(const NodeRef& lhs, const NodeRef& rhs) { + return 0; + // if (lhs.same_as(rhs)) return true; + // auto it = equal_map_.find(lhs); + // if (it != equal_map_.end()) { + // return it->second.same_as(rhs); + // } else { + // if (map_free_var_) { + // if (lhs->type_index() != rhs->type_index()) return false; + // equal_map_[lhs] = rhs; + // return true; + // } else { + // return false; + // } + // } + } + + using AttrsHashHandler::VisitAttr_; + size_t VisitAttr_(const Variable* lhs) final { + return 0; // return LeafNodeEqual(GetRef(lhs), other); + } + + // Type equality + size_t VisitType_(const TensorTypeNode* tensor_type) final { + size_t hash = std::hash()(tensor_type->_type_key); + hash = Combine(hash, DataTypeHash(tensor_type->dtype)); + hash = Combine(hash, Hash(tensor_type->shape)); + return hash; + } + + size_t VisitType_(const IncompleteTypeNode* incomplete) final { + return GetRef(incomplete); + } + + size_t VisitType_(const TypeVarNode* lhs) final { + return 0; + // if (const TypeVarNode* rhs = other.as()) { + // if (lhs->kind != rhs->kind) return false; + // return LeafNodeEqual(GetRef(lhs), other); + // } else { + // return false; + // } + } + + size_t VisitType_(const FuncTypeNode* func_type) final { + size_t hash = std::hash()(func_type->_type_key); + for (auto type_param : func_type->type_params) { + hash = Combine(hash, TypeHash(type_param)); + } + + for (auto arg : func_type->arg_types) { + hash = Combine(hash, TypeHash(arg)); + } + + hash = Combine(hash, TypeHash(func_type->ret_type)); + for (auto cs : func_type->type_constraints) { + hash = Combine(hash, TypeHash(cs)); + } + + return hash; + } + + size_t VisitType_(const TypeRelationNode* type_rel) final { + return GetRef(type_rel).hash(); + // if (const TypeRelationNode* rhs = other.as()) { + // if (lhs->func->name != rhs->func->name) return false; + // if (lhs->num_inputs != rhs->num_inputs) return false; + // if (!this->AttrEqual(lhs->attrs, rhs->attrs)) return false; 
+ // if (lhs->args.size() != rhs->args.size()) return false; + // for (size_t i = 0; i < lhs->args.size(); ++i) { + // if (!TypeEqual(lhs->args[i], rhs->args[i])) return false; + // } + // return true; + // } else { + // return false; + // } + } + + size_t VisitType_(const TupleTypeNode* tuple_type) final { + size_t hash = std::hash()(tuple_type->_type_key); + for (size_t i = 0; i < tuple_type->fields.size(); i++) { + hash = Combine(hash, TypeHash(tuple_type->fields[i])); + } + return hash; + } + + // Expr equal checking. + size_t NDArrayHash(const runtime::NDArray& array) { + return 0; + // if (lhs.defined() != rhs.defined()) { + // return false; + // } else if (lhs.same_as(rhs)) { + // return true; + // } else { + // auto ldt = lhs->dtype; + // auto rdt = rhs->dtype; + // CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor"; + // CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor"; + // if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) { + // size_t data_size = runtime::GetDataSize(*lhs.operator->()); + // return std::memcmp(lhs->data, rhs->data, data_size) == 0; + // } else { + // return false; + // } + // } + } + + int BindVar(const NodeRef& var) { + var_map_[var] = var_counter; + return var_counter++; + } + + size_t VisitExpr_(const VarNode* var) final { + return std::hash()(var_map_[GetRef(var)]); + } + + size_t VisitExpr_(const GlobalVarNode* global) final { + return GetRef(global).hash(); + } + + size_t VisitExpr_(const TupleNode* tuple) final { + size_t hash = std::hash()(tuple->_type_key); + for (size_t i = 0; i < tuple->fields.size(); i++) { + hash = Combine(hash, ExprHash(tuple->fields[i])); + } + return hash; + } + + size_t VisitExpr_(const FunctionNode* func) final { + size_t hash = std::hash()(func->_type_key); + for (auto type_param : func->type_params) { + hash = Combine(hash, TypeHash(type_param)); + } + + for (auto param : func->params) { + hash = Combine(hash, 
std::hash()(BindVar(param))); + } + + hash = Combine(hash, TypeHash(func->ret_type)); + hash = Combine(hash, ExprHash(func->body)); + + return hash; + } + + size_t VisitExpr_(const CallNode* call) final { + size_t hash = std::hash()(call->_type_key); + hash = Combine(hash, ExprHash(call->op)); + + for (auto arg : call->args) { + hash = Combine(hash, ExprHash(arg)); + } + + hash = Combine(hash, AttrHash(call->attrs)); + + return hash; + } + + size_t VisitExpr_(const LetNode* let) final { + size_t hash = std::hash()(let->_type_key); + hash = Combine(hash, std::hash()(BindVar(let->var))); + hash = Combine(hash, ExprHash(let->value)); + hash = Combine(hash, ExprHash(let->body)); + return hash; + } + + size_t VisitExpr_(const IfNode* ite) final { + size_t hash = std::hash()(ite->_type_key); + hash = Combine(hash, ExprHash(ite->cond)); + hash = Combine(hash, ExprHash(ite->true_branch)); + hash = Combine(hash, ExprHash(ite->false_branch)); + return hash; + } + + size_t VisitExpr_(const OpNode* op) final { + return GetRef(op).hash(); + } + + size_t VisitExpr_(const ConstantNode* rconst) final { + return NDArrayHash(rconst->data); + } + + size_t VisitExpr_(const TupleGetItemNode* get_item) final { + size_t hash = std::hash()(get_item->_type_key); + hash = Combine(hash, ExprHash(get_item->tuple)); + hash = Combine(hash, std::hash()(get_item->index)); + return hash; + } + + private: + // whether to map open terms. + bool map_free_var_{false}; + // renaming of NodeRef to indicate two nodes equals to each other + std::unordered_map hash_map_; + std::unordered_map var_map_; + int var_counter = 0; +}; + +size_t HashType(const Type& type) { + return RelayHashHandler(false).TypeHash(type); +} + +size_t HashExpr(const Expr& expr) { + return RelayHashHandler(false).ExprHash(expr); +} + +// // TODO(@jroesch): move to correct namespace? 
+// TVM_REGISTER_API("relay._make._alpha_equal") +// .set_body([](TVMArgs args, TVMRetValue* ret) { +// *ret = RelayHashHandler(false).Hash(args[0]); +// }); + +// TVM_REGISTER_API("relay._make._type_alpha_equal") +// .set_body([](TVMArgs args, TVMRetValue* ret) { +// *ret = RelayHashHandler(false).TypeHash(args[0]); +// }); + +// TVM_REGISTER_API("relay._make._graph_equal") +// .set_body([](TVMArgs args, TVMRetValue* ret) { +// *ret = AlphaEqualHandler(true).Equal(args[0], args[1]); +// }); + +} // namespace relay +} // namespace tvm From e949b9b6b9b1c697c50e96cdccbe22ba50c2d2fb Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 24 Oct 2018 16:19:35 -0700 Subject: [PATCH 02/15] Modify alpha_equal to test hashing too --- python/tvm/relay/ir_pass.py | 3 +++ src/relay/ir/hash.cc | 21 ++++++++++++--------- tests/python/relay/test_pass_alpha_equal.py | 5 ++++- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index c6d5aa7515bc..259a5aca919e 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -169,3 +169,6 @@ def graph_equal(lhs, rhs): True iff lhs is data-flow equivalent to rhs. """ return bool(_make._graph_equal(lhs, rhs)) + +def expr_hash(expr): + return bool(_ir_pass._expr_hash(expr)) diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index 856645a9a98f..4338c216fe1f 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -30,6 +30,7 @@ class RelayHashHandler: */ size_t Hash(const NodeRef& ref) { if (!ref.defined()) return ref.hash(); + if (ref->derived_from()) { return TypeHash(Downcast(ref)); } @@ -45,6 +46,7 @@ class RelayHashHandler: * \return the hash value */ size_t AttrHash(const NodeRef& ref) { + if (!ref.defined()) { return ref.hash(); } return AttrsHashHandler::Hash(ref); } /*! @@ -54,6 +56,7 @@ class RelayHashHandler: * \return the hash value. 
*/ size_t TypeHash(const Type& type) { + if (!type.defined()) { return type.hash(); } auto found = hash_map_.find(type); if (found != hash_map_.end()) { return found->second; @@ -312,16 +315,16 @@ size_t HashExpr(const Expr& expr) { return RelayHashHandler(false).ExprHash(expr); } -// // TODO(@jroesch): move to correct namespace? -// TVM_REGISTER_API("relay._make._alpha_equal") -// .set_body([](TVMArgs args, TVMRetValue* ret) { -// *ret = RelayHashHandler(false).Hash(args[0]); -// }); +// TODO(@jroesch): move to correct namespace? +TVM_REGISTER_API("relay._ir_pass._expr_hash") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = static_cast(RelayHashHandler(false).Hash(args[0])); + }); -// TVM_REGISTER_API("relay._make._type_alpha_equal") -// .set_body([](TVMArgs args, TVMRetValue* ret) { -// *ret = RelayHashHandler(false).TypeHash(args[0]); -// }); +TVM_REGISTER_API("relay._ir_pass._type_hash") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = static_cast(RelayHashHandler(false).TypeHash(args[0])); + }); // TVM_REGISTER_API("relay._make._graph_equal") // .set_body([](TVMArgs args, TVMRetValue* ret) { diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index d16c2df53435..0fe0d738bed6 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -1,7 +1,10 @@ import tvm import numpy as np from tvm import relay -from tvm.relay.ir_pass import alpha_equal +from tvm.relay import ir_pass + +def alpha_equal(x, y): + return ir_pass.alpha_equal(x, y) and ir_pass.expr_hash(x) == ir_pass.expr_hash(y) def test_tensor_type_alpha_equal(): t1 = relay.TensorType((3, 4), "float32") From da9b02177596ec43ea42ee4b102af6b3e8a5a0ad Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 24 Oct 2018 16:21:11 -0700 Subject: [PATCH 03/15] Add commentary --- tests/python/relay/test_pass_alpha_equal.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git 
a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index 0fe0d738bed6..3e4cb6cff76b 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -4,7 +4,14 @@ from tvm.relay import ir_pass def alpha_equal(x, y): - return ir_pass.alpha_equal(x, y) and ir_pass.expr_hash(x) == ir_pass.expr_hash(y) + """ + Wrapper around alpha equality which ensures that + the hash function respects equality. + """ + if ir_pass.alpha_equal(x, y): + return ir_pass.expr_hash(x) == ir_pass.expr_hash(y) + else: + return ir_pass.expr_hash(x) != ir_pass.expr_hash(y) def test_tensor_type_alpha_equal(): t1 = relay.TensorType((3, 4), "float32") From 3c7fcba3d71ac8f129ee513772074d71442a793c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 24 Oct 2018 16:53:38 -0700 Subject: [PATCH 04/15] Add case of NDArray --- src/relay/ir/hash.cc | 57 +++++++++++++++++--------------------------- 1 file changed, 22 insertions(+), 35 deletions(-) diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index 4338c216fe1f..615a610f4543 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -13,7 +13,7 @@ namespace tvm { namespace relay { -// Alpha equal handler for relay. +// Hash handler for Relay. class RelayHashHandler: public AttrsHashHandler, public TypeFunctor, @@ -23,10 +23,9 @@ class RelayHashHandler: : map_free_var_(map_free_var) {} /*! - * Check equality of two nodes. - * \param lhs The left hand operand. - * \param rhs The right hand operand. - * \return the compare result. + * Compute hash of a node. + * \param ref The node to hash. + * \return the hash value. */ size_t Hash(const NodeRef& ref) { if (!ref.defined()) return ref.hash(); @@ -91,10 +90,9 @@ class RelayHashHandler: protected: /*! - * \brief Check if data type equals each other. - * \param lhs The left hand operand. - * \param rhs The right hand operand. - * \return the compare result. + * \brief Hash a DataType. 
+ * \param dtype The dtype to hash. + * \return the hash value. */ size_t DataTypeHash(const DataType& dtype) { return std::hash()( @@ -132,7 +130,7 @@ class RelayHashHandler: return 0; // return LeafNodeEqual(GetRef(lhs), other); } - // Type equality + // Type hashing size_t VisitType_(const TensorTypeNode* tensor_type) final { size_t hash = std::hash()(tensor_type->_type_key); hash = Combine(hash, DataTypeHash(tensor_type->dtype)); @@ -144,14 +142,9 @@ class RelayHashHandler: return GetRef(incomplete); } - size_t VisitType_(const TypeVarNode* lhs) final { - return 0; - // if (const TypeVarNode* rhs = other.as()) { - // if (lhs->kind != rhs->kind) return false; - // return LeafNodeEqual(GetRef(lhs), other); - // } else { - // return false; - // } + size_t VisitType_(const TypeVarNode* tyvar) final { + int index = BindVar(GetRef(tyvar)); + return std::hash()(index); } size_t VisitType_(const FuncTypeNode* func_type) final { @@ -198,23 +191,17 @@ class RelayHashHandler: // Expr equal checking. 
size_t NDArrayHash(const runtime::NDArray& array) { - return 0; - // if (lhs.defined() != rhs.defined()) { - // return false; - // } else if (lhs.same_as(rhs)) { - // return true; - // } else { - // auto ldt = lhs->dtype; - // auto rdt = rhs->dtype; - // CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor"; - // CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor"; - // if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) { - // size_t data_size = runtime::GetDataSize(*lhs.operator->()); - // return std::memcmp(lhs->data, rhs->data, data_size) == 0; - // } else { - // return false; - // } - // } + + size_t hash = std::hash()(array->dtype.code); + hash = Combine(hash, std::hash()(array->dtype.bits)); + hash = Combine(hash, std::hash()(array->dtype.lanes)); + CHECK_EQ(array->ctx.device_type, kDLCPU) << "can only compare CPU tensor"; + size_t data_size = runtime::GetDataSize(*array.operator->()); + uint8_t * data = reinterpret_cast(array->data); + for (size_t i = 0; i < data_size; i++) { + hash = Combine(hash, std::hash()(data[i])); + } + return hash; } int BindVar(const NodeRef& var) { From 3d85e53cf1ebdda904c7f201a66de8b4b3658a7f Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 24 Oct 2018 17:07:35 -0700 Subject: [PATCH 05/15] Add case for type relation --- src/relay/ir/hash.cc | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index 615a610f4543..ae6b4b05522a 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -166,19 +166,15 @@ class RelayHashHandler: } size_t VisitType_(const TypeRelationNode* type_rel) final { - return GetRef(type_rel).hash(); - // if (const TypeRelationNode* rhs = other.as()) { - // if (lhs->func->name != rhs->func->name) return false; - // if (lhs->num_inputs != rhs->num_inputs) return false; - // if (!this->AttrEqual(lhs->attrs, rhs->attrs)) return false; - // if (lhs->args.size() != 
rhs->args.size()) return false; - // for (size_t i = 0; i < lhs->args.size(); ++i) { - // if (!TypeEqual(lhs->args[i], rhs->args[i])) return false; - // } - // return true; - // } else { - // return false; - // } + size_t hash = std::hash()(type_rel->_type_key); + hash = Combine(hash, std::hash()(type_rel->func->name)); + hash = Combine(hash, AttrHash(type_rel->attrs)); + + for (auto arg : type_rel->args) { + hash = Combine(hash, TypeHash(arg)); + } + + return hash; } size_t VisitType_(const TupleTypeNode* tuple_type) final { From 84b1f0329e00e069b9b650f2f334ee12448653c3 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 24 Oct 2018 21:58:33 -0700 Subject: [PATCH 06/15] Adress some code review --- include/tvm/relay/pass.h | 14 ++++++++++++-- src/relay/ir/hash.cc | 41 +++++++++++++++++++--------------------- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 4d488bb41250..40793ae297b0 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -138,13 +138,23 @@ Expr DeadCodeElimination(const Expr& e); /*! \brief Hash a Relay type. * + * Implements structural hashing of a Relay type. + * + * \param type the type to hash. + * + * \return the hash value. */ -size_t HashType(const Expr &); +size_t HashType(const Type& type); /*! \brief Hash a Relay expression. * + * Implements structural hashing of a Relay expression. + * + * \param expr the expression to hash. + * + * \return the hash value. 
*/ -size_t HashExpr(const Expr &); +size_t HashExpr(const Expr& expr); } // namespace relay diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index ae6b4b05522a..2275e95918d6 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -8,6 +8,7 @@ #include #include #include "type_functor.h" +#include #include "../../lang/attr_functor.h" namespace tvm { @@ -19,8 +20,7 @@ class RelayHashHandler: public TypeFunctor, public ExprFunctor { public: - explicit RelayHashHandler(bool map_free_var) - : map_free_var_(map_free_var) {} + explicit RelayHashHandler() {} /*! * Compute hash of a node. @@ -95,10 +95,7 @@ class RelayHashHandler: * \return the hash value. */ size_t DataTypeHash(const DataType& dtype) { - return std::hash()( - static_cast(dtype.code()) | - (static_cast(dtype.bits()) << 8) | - (static_cast(dtype.lanes()) << 16)); + return ::tvm::AttrsHash()(dtype); } /*! @@ -149,8 +146,9 @@ class RelayHashHandler: size_t VisitType_(const FuncTypeNode* func_type) final { size_t hash = std::hash()(func_type->_type_key); + for (auto type_param : func_type->type_params) { - hash = Combine(hash, TypeHash(type_param)); + hash = Combine(hash, BindVar(type_param)); } for (auto arg : func_type->arg_types) { @@ -200,17 +198,19 @@ class RelayHashHandler: return hash; } - int BindVar(const NodeRef& var) { - var_map_[var] = var_counter; - return var_counter++; + size_t BindVar(const NodeRef& var) { + size_t hash = std::hash()(var_counter++); + CHECK(hash_map_.find(var) == hash_map_.end()); + hash_map_[var] = hash; + return hash; } size_t VisitExpr_(const VarNode* var) final { - return std::hash()(var_map_[GetRef(var)]); + } size_t VisitExpr_(const GlobalVarNode* global) final { - return GetRef(global).hash(); + return std::hash()(global->name_hint); } size_t VisitExpr_(const TupleNode* tuple) final { @@ -224,11 +224,11 @@ class RelayHashHandler: size_t VisitExpr_(const FunctionNode* func) final { size_t hash = std::hash()(func->_type_key); for (auto type_param : 
func->type_params) { - hash = Combine(hash, TypeHash(type_param)); + hash = Combine(hash, BindVar(type_param)); } for (auto param : func->params) { - hash = Combine(hash, std::hash()(BindVar(param))); + hash = Combine(hash, BindVar(param)); } hash = Combine(hash, TypeHash(func->ret_type)); @@ -252,7 +252,7 @@ class RelayHashHandler: size_t VisitExpr_(const LetNode* let) final { size_t hash = std::hash()(let->_type_key); - hash = Combine(hash, std::hash()(BindVar(let->var))); + hash = Combine(hash, BindVar(let->var)); hash = Combine(hash, ExprHash(let->value)); hash = Combine(hash, ExprHash(let->body)); return hash; @@ -282,31 +282,28 @@ class RelayHashHandler: } private: - // whether to map open terms. - bool map_free_var_{false}; // renaming of NodeRef to indicate two nodes equals to each other std::unordered_map hash_map_; - std::unordered_map var_map_; int var_counter = 0; }; size_t HashType(const Type& type) { - return RelayHashHandler(false).TypeHash(type); + return RelayHashHandler().TypeHash(type); } size_t HashExpr(const Expr& expr) { - return RelayHashHandler(false).ExprHash(expr); + return RelayHashHandler().ExprHash(expr); } // TODO(@jroesch): move to correct namespace? 
TVM_REGISTER_API("relay._ir_pass._expr_hash") .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = static_cast(RelayHashHandler(false).Hash(args[0])); + *ret = static_cast(RelayHashHandler().Hash(args[0])); }); TVM_REGISTER_API("relay._ir_pass._type_hash") .set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = static_cast(RelayHashHandler(false).TypeHash(args[0])); + *ret = static_cast(RelayHashHandler().TypeHash(args[0])); }); // TVM_REGISTER_API("relay._make._graph_equal") From f98fc2edf9188ec160ab69518af7587ac3ab383b Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 24 Oct 2018 22:11:37 -0700 Subject: [PATCH 07/15] Address 3 more cases --- src/relay/ir/hash.cc | 44 +++++++++++++++----------------------------- 1 file changed, 15 insertions(+), 29 deletions(-) diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index 2275e95918d6..5fc63d11ae2a 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -7,8 +7,8 @@ #include #include #include -#include "type_functor.h" #include +#include "type_functor.h" #include "../../lang/attr_functor.h" namespace tvm { @@ -98,33 +98,14 @@ class RelayHashHandler: return ::tvm::AttrsHash()(dtype); } - /*! - * \brief Check Equality of leaf node of the graph. - * if map_free_var_ is set to true, try to map via equal node. - * \param lhs The left hand operand. - * \param rhs The right hand operand. - * \return the compare result. 
- */ - size_t LeafNodeEqual(const NodeRef& lhs, const NodeRef& rhs) { - return 0; - // if (lhs.same_as(rhs)) return true; - // auto it = equal_map_.find(lhs); - // if (it != equal_map_.end()) { - // return it->second.same_as(rhs); - // } else { - // if (map_free_var_) { - // if (lhs->type_index() != rhs->type_index()) return false; - // equal_map_[lhs] = rhs; - // return true; - // } else { - // return false; - // } - // } - } - using AttrsHashHandler::VisitAttr_; - size_t VisitAttr_(const Variable* lhs) final { - return 0; // return LeafNodeEqual(GetRef(lhs), other); + size_t VisitAttr_(const Variable* var) final { + auto it = hash_map_.find(GetRef(var)); + if (it != hash_map_.end()) { + return it->second; + } + + return std::hash()(var->name_hint); } // Type hashing @@ -185,7 +166,6 @@ class RelayHashHandler: // Expr equal checking. size_t NDArrayHash(const runtime::NDArray& array) { - size_t hash = std::hash()(array->dtype.code); hash = Combine(hash, std::hash()(array->dtype.bits)); hash = Combine(hash, std::hash()(array->dtype.lanes)); @@ -202,11 +182,17 @@ class RelayHashHandler: size_t hash = std::hash()(var_counter++); CHECK(hash_map_.find(var) == hash_map_.end()); hash_map_[var] = hash; + + const auto* ty_param = var.as(); + if (ty_param && ty_param->kind == TypeVarNode::Kind::kShapeVar) { + hash_map_[ty_param->var] = hash; + } return hash; } size_t VisitExpr_(const VarNode* var) final { - + size_t name_hash = std::hash()(var->name_hint); + return Combine(name_hash, TypeHash(var->type_annotation)); } size_t VisitExpr_(const GlobalVarNode* global) final { From 69cda63850527f43a31aaa8f748af49bf3250f9e Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 24 Oct 2018 22:14:34 -0700 Subject: [PATCH 08/15] Fix Python wrapper for hashing --- python/tvm/relay/ir_pass.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 259a5aca919e..246faeba3348 100644 --- 
a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -171,4 +171,16 @@ def graph_equal(lhs, rhs): return bool(_make._graph_equal(lhs, rhs)) def expr_hash(expr): - return bool(_ir_pass._expr_hash(expr)) + """Hash a Relay expression structurally. + + Parameters + ---------- + expr: tvm.relay.Expr + The expression to hash. + + Returns + ------- + result: int + The hash value + """ + return int(_ir_pass._expr_hash(expr)) From c13c7a66a80b9e8b6b47f5493eb0d72158a54ef7 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 24 Oct 2018 22:25:10 -0700 Subject: [PATCH 09/15] Fix test --- tests/python/relay/test_pass_alpha_equal.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index 3e4cb6cff76b..0523b0d1a33d 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -8,10 +8,7 @@ def alpha_equal(x, y): Wrapper around alpha equality which ensures that the hash function respects equality. 
""" - if ir_pass.alpha_equal(x, y): - return ir_pass.expr_hash(x) == ir_pass.expr_hash(y) - else: - return ir_pass.expr_hash(x) != ir_pass.expr_hash(y) + return ir_pass.alpha_equal(x, y) and ir_pass.expr_hash(x) == ir_pass.expr_hash(y) def test_tensor_type_alpha_equal(): t1 = relay.TensorType((3, 4), "float32") From 43e712002359fad4e1fea067e3716214d3b700e8 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 25 Oct 2018 12:53:34 -0700 Subject: [PATCH 10/15] Address CR feedback --- include/tvm/relay/pass.h | 4 ++-- python/tvm/relay/ir_pass.py | 17 +++++++++++++---- src/relay/ir/hash.cc | 28 ++++++++++++++++------------ 3 files changed, 31 insertions(+), 18 deletions(-) diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 40793ae297b0..bf16c7ed8e33 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -144,7 +144,7 @@ Expr DeadCodeElimination(const Expr& e); * * \return the hash value. */ -size_t HashType(const Type& type); +size_t StructuralHash(const Type& type); /*! \brief Hash a Relay expression. * @@ -154,7 +154,7 @@ size_t HashType(const Type& type); * * \return the hash value. */ -size_t HashExpr(const Expr& expr); +size_t StructuralHash(const Expr& expr); } // namespace relay diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 246faeba3348..03a879daa711 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -1,4 +1,4 @@ -# pylint: disable=no-else-return, +# pylint: disable=no-else-return # pylint: disable=unidiomatic-typecheck """The set of passes for Relay. @@ -7,6 +7,7 @@ """ from . import _ir_pass from . import _make +from .. import Expr, Type # pylint: disable=invalid-name def infer_type(expr, env=None): @@ -148,6 +149,9 @@ def alpha_equal(lhs, rhs): """ return bool(_make._alpha_equal(lhs, rhs)) +lower_ops = _ir_pass.LowerOps +fuse_ops = _ir_pass.FuseOps +monomorph = _ir_pass.Monomorph def graph_equal(lhs, rhs): """Compare two Relay expr for data-flow equivalence. 
@@ -170,12 +174,12 @@ def graph_equal(lhs, rhs): """ return bool(_make._graph_equal(lhs, rhs)) -def expr_hash(expr): +def structural_hash(value): """Hash a Relay expression structurally. Parameters ---------- - expr: tvm.relay.Expr + value: tvm.relay.Expr or tvm.relay.Type The expression to hash. Returns @@ -183,4 +187,9 @@ result: int The hash value """ - return int(_ir_pass._expr_hash(expr)) + if isinstance(value, Expr): + return int(_ir_pass._expr_hash(value)) + elif isinstance(value, Type): + return int(_ir_pass._type_hash(value)) + else: + raise TypeError("found value of type {0} expected relay.Expr or relay.Type".format(type(value))) diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index 5fc63d11ae2a..3a0e98fb7d12 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -121,8 +121,18 @@ class RelayHashHandler: } size_t VisitType_(const TypeVarNode* tyvar) final { - int index = BindVar(GetRef(tyvar)); - return std::hash()(index); + /* + TypeVar/Var/Variable have two locations where they are hashed: + + The declaration site of a function, let, or function type. + The first occurrence in the term. + + We will only reach this code if the TypeVar itself is unbound; we assign + a free variable index to it, meaning this hashing function implements + structural equality for both open (i.e. graph equality) and closed terms + (i.e. alpha_equality). + */ + return BindVar(GetRef(tyvar)); } size_t VisitType_(const FuncTypeNode* func_type) final { @@ -164,7 +174,7 @@ class RelayHashHandler: return hash; } - // Expr equal checking. + // Expr hashing. 
size_t NDArrayHash(const runtime::NDArray& array) { size_t hash = std::hash()(array->dtype.code); hash = Combine(hash, std::hash()(array->dtype.bits)); @@ -180,7 +190,7 @@ class RelayHashHandler: size_t BindVar(const NodeRef& var) { size_t hash = std::hash()(var_counter++); - CHECK(hash_map_.find(var) == hash_map_.end()); + CHECK(hash_map_.count(var) == 0); hash_map_[var] = hash; const auto* ty_param = var.as(); @@ -273,15 +283,14 @@ class RelayHashHandler: int var_counter = 0; }; -size_t HashType(const Type& type) { +size_t StructuralHash(const Type& type) { return RelayHashHandler().TypeHash(type); } -size_t HashExpr(const Expr& expr) { +size_t StructuralHash(const Expr& expr) { return RelayHashHandler().ExprHash(expr); } -// TODO(@jroesch): move to correct namespace? TVM_REGISTER_API("relay._ir_pass._expr_hash") .set_body([](TVMArgs args, TVMRetValue* ret) { *ret = static_cast(RelayHashHandler().Hash(args[0])); @@ -292,10 +301,5 @@ TVM_REGISTER_API("relay._ir_pass._type_hash") *ret = static_cast(RelayHashHandler().TypeHash(args[0])); }); -// TVM_REGISTER_API("relay._make._graph_equal") -// .set_body([](TVMArgs args, TVMRetValue* ret) { -// *ret = AlphaEqualHandler(true).Equal(args[0], args[1]); -// }); - } // namespace relay } // namespace tvm From de04837628545b4406913a0dc90131671bc7b652 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 25 Oct 2018 13:00:36 -0700 Subject: [PATCH 11/15] Fix another comment --- src/relay/ir/hash.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index 3a0e98fb7d12..4fad17cc5d1e 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -117,7 +117,8 @@ class RelayHashHandler: } size_t VisitType_(const IncompleteTypeNode* incomplete) final { - return GetRef(incomplete); + size_t hash = std::hash()(incomplete->_type_key); + return Combine(hash, std::hash()(incomplete->kind)); } size_t VisitType_(const TypeVarNode* tyvar) final { @@ -190,7 +191,7 @@ class 
RelayHashHandler: size_t BindVar(const NodeRef& var) { size_t hash = std::hash()(var_counter++); - CHECK(hash_map_.count(var) == 0); + CHECK_EQ(hash_map_.count(var), 0); hash_map_[var] = hash; const auto* ty_param = var.as(); From f5e9c80ee1efd0868e7558c742ae32b440d472d4 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 25 Oct 2018 13:02:59 -0700 Subject: [PATCH 12/15] Fix indentation --- python/tvm/relay/ir_pass.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 03a879daa711..7360d178e850 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -188,8 +188,10 @@ def structural_hash(value): The hash value """ if isinstance(value, Expr): - return int(_ir_pass._expr_hash(value)) + return int(_ir_pass._expr_hash(value)) elif isinstance(value, Type): - return int(_ir_pass._type_hash(value)) + return int(_ir_pass._type_hash(value)) else: - raise TypeError("found value of type {0} expected relay.Expr or relay.Type".format(type(value))) + msg = ("found value of type {0} expected " + + "relay.Expr or relay.Type").format(type(value)) + raise TypeError(msg) From 98c97627dab75f902ad68fbc1803976762e96a64 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 25 Oct 2018 13:09:26 -0700 Subject: [PATCH 13/15] Final feedback --- src/relay/ir/hash.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index 4fad17cc5d1e..3aa567a4892e 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -105,7 +105,9 @@ class RelayHashHandler: return it->second; } - return std::hash()(var->name_hint); + + size_t hash = std::hash()(var->_type_key); + return Combine(hash, std::hash()(var->name_hint)); } // Type hashing From 97c85fb49448687a855aaa28c4c3abc5fade6d1d Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 25 Oct 2018 13:19:04 -0700 Subject: [PATCH 14/15] Fix import --- python/tvm/relay/ir_pass.py | 4
++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 7360d178e850..04a8a7f687c7 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -7,8 +7,8 @@ """ from . import _ir_pass from . import _make -from .. import Expr, Type -# pylint: disable=invalid-name +from .expr import Expr +from .ty import Type def infer_type(expr, env=None): """Infer the type of expr under the context of env. From 718ff7478458328ab2cb42f448f6c343149b91b0 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 25 Oct 2018 13:20:05 -0700 Subject: [PATCH 15/15] Fix tests --- python/tvm/relay/ir_pass.py | 4 ---- tests/python/relay/test_pass_alpha_equal.py | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 04a8a7f687c7..f930751c41a7 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -149,10 +149,6 @@ def alpha_equal(lhs, rhs): """ return bool(_make._alpha_equal(lhs, rhs)) -lower_ops = _ir_pass.LowerOps -fuse_ops = _ir_pass.FuseOps -monomorph = _ir_pass.Monomorph - def graph_equal(lhs, rhs): """Compare two Relay expr for data-flow equivalence. The difference between this and alpha-equality is that diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index 0523b0d1a33d..5158d5c7cc9c 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -8,7 +8,7 @@ def alpha_equal(x, y): Wrapper around alpha equality which ensures that the hash function respects equality. """ - return ir_pass.alpha_equal(x, y) and ir_pass.expr_hash(x) == ir_pass.expr_hash(y) + return ir_pass.alpha_equal(x, y) and ir_pass.structural_hash(x) == ir_pass.structural_hash(y) def test_tensor_type_alpha_equal(): t1 = relay.TensorType((3, 4), "float32")