diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 08fe957d8a78..8f51ea93821d 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -157,16 +157,13 @@ class TypeVar; /*! \brief TypeVar container node */ class TypeVarNode : public TypeNode { public: - /*! - * \brief The variable itself is only meaningful when - * kind is ShapeVar, otherwise, we only use the name. - */ - tvm::Var var; + /*! \brief Name of the variable, it only acts as a hint. */ + std::string name_hint; /*! \brief The kind of type parameter */ Kind kind; void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("var", &var); + v->Visit("name_hint", &name_hint); v->Visit("kind", &kind); v->Visit("span", &span); } @@ -189,16 +186,13 @@ class GlobalTypeVar; /*! \brief GlobalTypeVar container node */ class GlobalTypeVarNode : public TypeNode { public: - /*! - * \brief The variable itself is only meaningful when - * kind is ShapeVar; otherwise, we only use the name. - */ - tvm::Var var; + /*! \brief Name of the variable, it only acts as a hint. */ + std::string name_hint; /*! \brief The kind of type parameter */ Kind kind; void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("var", &var); + v->Visit("name_hint", &name_hint); v->Visit("kind", &kind); v->Visit("span", &span); } diff --git a/python/tvm/relay/_parser.py b/python/tvm/relay/_parser.py index 45822c56ede2..1f0088a76844 100644 --- a/python/tvm/relay/_parser.py +++ b/python/tvm/relay/_parser.py @@ -272,7 +272,7 @@ def _check_existing_typ_expr(self, name, new_expr): def _type_expr_name(self, e): if isinstance(e, adt.Constructor): - return "`{0}` ADT constructor".format(e.belong_to.var.name) + return "`{0}` ADT constructor".format(e.belong_to.name_hint) elif isinstance(e, ty.GlobalTypeVar): if e.kind == ty.Kind.AdtHandle: return "ADT definition" diff --git a/python/tvm/relay/type_functor.py b/python/tvm/relay/type_functor.py index 1331058b37ca..7139ccb4c042 100644 --- a/python/tvm/relay/type_functor.py +++ b/python/tvm/relay/type_functor.py @@ -143,7 +143,7 @@ class TypeMutator(TypeFunctor): and reconstructs the AST. """ def visit_type_var(self, tv): - return TypeVar(tv.var.name, tv.kind) + return TypeVar(tv.name_hint, tv.kind) def visit_incomplete_type(self, it): return IncompleteType(it.kind) @@ -180,7 +180,7 @@ def visit_ref_type(self, rt): return RefType(self.visit(rt.value)) def visit_global_type_var(self, gtv): - return GlobalTypeVar(gtv.var.name, gtv.kind) + return GlobalTypeVar(gtv.name_hint, gtv.kind) def visit_type_call(self, tc): return TypeCall( diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 589de09b0b81..d8dcddde5758 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -69,8 +69,8 @@ class AlphaEqualHandler: } if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false; for (const auto& p : lhsm->type_definitions) { - if (!rhsm->ContainGlobalTypeVar(p.first->var->name_hint) || - !Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) { + if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) || + !Equal(p.second, rhsm->LookupDef(p.first->name_hint))) { return false; } } @@ -233,11 +233,6 @@ class AlphaEqualHandler: return false; } equal_map_[lhs->type_params[i]] = rhs->type_params[i]; - // set up type parameter equal - if (lhs->type_params[i]->kind == Kind::kShapeVar) { - // map variable - equal_map_[lhs->type_params[i]->var] = rhs->type_params[i]->var; - } } for (size_t i = 0; i < lhs->arg_types.size(); i++) { if (!TypeEqual(lhs->arg_types[i], rhs->arg_types[i])) return false; diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index 15f5105808aa..459e8b09c94e 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -228,11 +228,11 @@ class RelayHashHandler: hash = Combine(hash, TypeHash(var_node->type_annotation)); } hash_map_[var] = hash; - - const auto* ty_param = var.as(); - if (ty_param && ty_param->kind == Kind::kShapeVar) { - hash_map_[ty_param->var] = hash; - } + // TODO(tqchen) Introduce TypeVarExpr + // const auto* ty_param = var.as(); + // if (ty_param && ty_param->kind == Kind::kShapeVar) { + // hash_map_[ty_param->var] = hash; + // } return hash; } diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 2fa79c7b6322..38f86a50df57 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -55,9 +55,9 @@ Module ModuleNode::make(tvm::Map global_funcs, for (const auto& kv : n->type_definitions) { // set global typevar map - CHECK(n->global_type_var_map_.count(kv.first->var->name_hint) == 0) - << "Duplicate global type definition name " << kv.first->var->name_hint; - n->global_type_var_map_.Set(kv.first->var->name_hint, kv.first); + CHECK(n->global_type_var_map_.count(kv.first->name_hint) == 0) + << "Duplicate global type definition name " << kv.first->name_hint; + n->global_type_var_map_.Set(kv.first->name_hint, kv.first); n->RegisterConstructors(kv.first, kv.second); } @@ -177,7 +177,7 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& // We hash the global type var name to use as a globally unique prefix for tags. // The hash will be used as the most significant byte of the tag, with the index of // the constructor in the less significant bytes - size_t hash = std::hash()(var->var->name_hint); + size_t hash = std::hash()(var->name_hint); int32_t prefix = static_cast(hash & 0xff) << 24; for (size_t i = 0; i < type->constructors.size(); ++i) { type->constructors[i]->tag = prefix | static_cast(i); @@ -197,10 +197,10 @@ void ModuleNode::AddDefUnchecked(const GlobalTypeVar& var, const TypeData& type, this->type_definitions.Set(var, type); if (!update) { // set global type var map - CHECK(global_type_var_map_.count(var->var->name_hint) == 0) - << "Duplicate global type definition name " << var->var->name_hint; + CHECK(global_type_var_map_.count(var->name_hint) == 0) + << "Duplicate global type definition name " << var->name_hint; } - global_type_var_map_.Set(var->var->name_hint, var); + global_type_var_map_.Set(var->name_hint, var); RegisterConstructors(var, type); } @@ -234,7 +234,7 @@ Function ModuleNode::Lookup(const std::string& name) const { TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) const { auto it = type_definitions.find(var); CHECK(it != type_definitions.end()) - << "There is no definition of " << var->var->name_hint; + << "There is no definition of " << var->name_hint; return (*it).second; } diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 478469c586ef..992684400c04 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -312,7 +312,7 @@ class PrettyPrinter : val << "-malformed-ir"; return val; } - std::string name = var->var->name_hint; + std::string name = var->name_hint; if (name.length() == 0 || !std::isalpha(name[0])) { name = "t" + name; } @@ -493,7 +493,7 @@ class PrettyPrinter : doc << "["; std::vector type_params; for (const TypeVar& tv : fn->type_params) { - type_params.push_back(Doc(tv->var->name_hint)); + type_params.push_back(Doc(tv->name_hint)); } doc << PrintSep(type_params); doc << "]"; @@ -701,11 +701,11 @@ class PrettyPrinter : } Doc VisitType_(const TypeVarNode* node) final { - return Doc(node->var->name_hint); + return Doc(node->name_hint); } Doc VisitType_(const GlobalTypeVarNode* node) final { - return Doc(node->var->name_hint); + return Doc(node->name_hint); } Doc VisitType_(const TypeCallNode* node) final { diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index 70071d0445aa..48f211b4006e 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -65,7 +65,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TypeVar TypeVarNode::make(std::string name, Kind kind) { ObjectPtr n = make_object(); - n->var = tvm::Var(name); + n->name_hint = std::move(name); n->kind = std::move(kind); return TypeVar(n); } @@ -74,19 +74,19 @@ TVM_REGISTER_NODE_TYPE(TypeVarNode); TVM_REGISTER_API("relay._make.TypeVar") .set_body_typed([](std::string name, int kind) { - return TypeVarNode::make(name, static_cast(kind)); - }); + return TypeVarNode::make(name, static_cast(kind)); +}); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch([](const ObjectRef& ref, IRPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "TypeVarNode(" << node->var->name_hint << ", " - << node->kind << ")"; + auto* node = static_cast(ref.get()); + p->stream << "TypeVarNode(" << node->name_hint << ", " + << node->kind << ")"; }); GlobalTypeVar GlobalTypeVarNode::make(std::string name, Kind kind) { ObjectPtr n = make_object(); - n->var = tvm::Var(name); + n->name_hint = std::move(name); n->kind = std::move(kind); return GlobalTypeVar(n); } @@ -101,7 +101,7 @@ TVM_REGISTER_API("relay._make.GlobalTypeVar") TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch([](const ObjectRef& ref, IRPrinter* p) { auto* node = static_cast(ref.get()); - p->stream << "GlobalTypeVarNode(" << node->var->name_hint << ", " + p->stream << "GlobalTypeVarNode(" << node->name_hint << ", " << node->kind << ")"; }); diff --git a/src/relay/pass/de_duplicate.cc b/src/relay/pass/de_duplicate.cc index 6816cc7d2d83..cf99dc30fc41 100644 --- a/src/relay/pass/de_duplicate.cc +++ b/src/relay/pass/de_duplicate.cc @@ -37,7 +37,7 @@ Expr DeDup(const Expr& e) { public PatternMutator { public: TypeVar Fresh(const TypeVar& tv) { - TypeVar ret = TypeVarNode::make(tv->var->name_hint, tv->kind); + TypeVar ret = TypeVarNode::make(tv->name_hint, tv->kind); type_rename_[tv] = ret; return ret; } diff --git a/src/relay/pass/to_cps.cc b/src/relay/pass/to_cps.cc index 1dfa327d8b0e..96e7f1a78f75 100644 --- a/src/relay/pass/to_cps.cc +++ b/src/relay/pass/to_cps.cc @@ -334,7 +334,7 @@ Function UnCPS(const Function& f) { auto new_ret_type = Type(cont_type->arg_types[0]); std::vector new_type_params; for (const auto& tp : f->type_params) { - new_type_params.push_back(TypeVarNode::make(tp->var->name_hint, tp->kind)); + new_type_params.push_back(TypeVarNode::make(tp->name_hint, tp->kind)); } auto answer_type = new_type_params.back(); new_type_params.pop_back(); diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index d7246da89226..58f724944e63 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -498,15 +498,16 @@ def test_fused_ops(): tvm.testing.assert_allclose(result.asnumpy(), (data + 1) * 2) def test_arange_with_dynamic_shape(): - m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k') - x = relay.var('x', shape=(m.var, n.var, k.var), dtype='float32') + # m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k') + m, n, k = relay.Any(), relay.Any(), relay.Any() + x = relay.var('x', shape=(m, n, k), dtype='float32') y0 = relay.shape_of(x) y1 = relay.take(y0, relay.const(0, 'int32')) y2 = relay.op.arange(y1, dtype="int32") y3 = y2 + relay.const(1, dtype="int32") data = np.random.rand(10, 5, 3).astype('float32') mod = relay.module.Module() - mod["main"] = relay.Function([x], y3, type_params=[m, n, k]) + mod["main"] = relay.Function([x], y3) for kind in ["debug", "vm"]: ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") result = ex.evaluate()(data)