Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 6 additions & 12 deletions include/tvm/relay/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,16 +157,13 @@ class TypeVar;
/*! \brief TypeVar container node */
class TypeVarNode : public TypeNode {
public:
/*!
* \brief The variable itself is only meaningful when
* kind is ShapeVar, otherwise, we only use the name.
*/
tvm::Var var;
/*! \brief Name of the variable, it only acts as a hint. */
std::string name_hint;
/*! \brief The kind of type parameter */
Kind kind;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("var", &var);
v->Visit("name_hint", &name_hint);
v->Visit("kind", &kind);
v->Visit("span", &span);
}
Expand All @@ -189,16 +186,13 @@ class GlobalTypeVar;
/*! \brief GlobalTypeVar container node */
class GlobalTypeVarNode : public TypeNode {
public:
/*!
* \brief The variable itself is only meaningful when
* kind is ShapeVar; otherwise, we only use the name.
*/
tvm::Var var;
/*! \brief Name of the variable, it only acts as a hint. */
std::string name_hint;
/*! \brief The kind of type parameter */
Kind kind;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("var", &var);
v->Visit("name_hint", &name_hint);
v->Visit("kind", &kind);
v->Visit("span", &span);
}
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def _check_existing_typ_expr(self, name, new_expr):

def _type_expr_name(self, e):
if isinstance(e, adt.Constructor):
return "`{0}` ADT constructor".format(e.belong_to.var.name)
return "`{0}` ADT constructor".format(e.belong_to.name_hint)
elif isinstance(e, ty.GlobalTypeVar):
if e.kind == ty.Kind.AdtHandle:
return "ADT definition"
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/type_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ class TypeMutator(TypeFunctor):
and reconstructs the AST.
"""
def visit_type_var(self, tv):
return TypeVar(tv.var.name, tv.kind)
return TypeVar(tv.name_hint, tv.kind)

def visit_incomplete_type(self, it):
return IncompleteType(it.kind)
Expand Down Expand Up @@ -180,7 +180,7 @@ def visit_ref_type(self, rt):
return RefType(self.visit(rt.value))

def visit_global_type_var(self, gtv):
return GlobalTypeVar(gtv.var.name, gtv.kind)
return GlobalTypeVar(gtv.name_hint, gtv.kind)

def visit_type_call(self, tc):
return TypeCall(
Expand Down
9 changes: 2 additions & 7 deletions src/relay/ir/alpha_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ class AlphaEqualHandler:
}
if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
for (const auto& p : lhsm->type_definitions) {
if (!rhsm->ContainGlobalTypeVar(p.first->var->name_hint) ||
!Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) {
if (!rhsm->ContainGlobalTypeVar(p.first->name_hint) ||
!Equal(p.second, rhsm->LookupDef(p.first->name_hint))) {
return false;
}
}
Expand Down Expand Up @@ -233,11 +233,6 @@ class AlphaEqualHandler:
return false;
}
equal_map_[lhs->type_params[i]] = rhs->type_params[i];
// set up type parameter equal
if (lhs->type_params[i]->kind == Kind::kShapeVar) {
// map variable
equal_map_[lhs->type_params[i]->var] = rhs->type_params[i]->var;
}
}
for (size_t i = 0; i < lhs->arg_types.size(); i++) {
if (!TypeEqual(lhs->arg_types[i], rhs->arg_types[i])) return false;
Expand Down
10 changes: 5 additions & 5 deletions src/relay/ir/hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,11 @@ class RelayHashHandler:
hash = Combine(hash, TypeHash(var_node->type_annotation));
}
hash_map_[var] = hash;

const auto* ty_param = var.as<TypeVarNode>();
if (ty_param && ty_param->kind == Kind::kShapeVar) {
hash_map_[ty_param->var] = hash;
}
// TODO(tqchen) Introduce TypeVarExpr
// const auto* ty_param = var.as<TypeVarNode>();
// if (ty_param && ty_param->kind == Kind::kShapeVar) {
// hash_map_[ty_param->var] = hash;
// }
return hash;
}

Expand Down
16 changes: 8 additions & 8 deletions src/relay/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,

for (const auto& kv : n->type_definitions) {
// set global typevar map
CHECK(n->global_type_var_map_.count(kv.first->var->name_hint) == 0)
<< "Duplicate global type definition name " << kv.first->var->name_hint;
n->global_type_var_map_.Set(kv.first->var->name_hint, kv.first);
CHECK(n->global_type_var_map_.count(kv.first->name_hint) == 0)
<< "Duplicate global type definition name " << kv.first->name_hint;
n->global_type_var_map_.Set(kv.first->name_hint, kv.first);
n->RegisterConstructors(kv.first, kv.second);
}

Expand Down Expand Up @@ -177,7 +177,7 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData&
// We hash the global type var name to use as a globally unique prefix for tags.
// The hash will be used as the most significant byte of the tag, with the index of
// the constructor in the less significant bytes
size_t hash = std::hash<std::string>()(var->var->name_hint);
size_t hash = std::hash<std::string>()(var->name_hint);
int32_t prefix = static_cast<int32_t>(hash & 0xff) << 24;
for (size_t i = 0; i < type->constructors.size(); ++i) {
type->constructors[i]->tag = prefix | static_cast<int32_t>(i);
Expand All @@ -197,10 +197,10 @@ void ModuleNode::AddDefUnchecked(const GlobalTypeVar& var, const TypeData& type,
this->type_definitions.Set(var, type);
if (!update) {
// set global type var map
CHECK(global_type_var_map_.count(var->var->name_hint) == 0)
<< "Duplicate global type definition name " << var->var->name_hint;
CHECK(global_type_var_map_.count(var->name_hint) == 0)
<< "Duplicate global type definition name " << var->name_hint;
}
global_type_var_map_.Set(var->var->name_hint, var);
global_type_var_map_.Set(var->name_hint, var);
RegisterConstructors(var, type);
}

Expand Down Expand Up @@ -234,7 +234,7 @@ Function ModuleNode::Lookup(const std::string& name) const {
TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) const {
auto it = type_definitions.find(var);
CHECK(it != type_definitions.end())
<< "There is no definition of " << var->var->name_hint;
<< "There is no definition of " << var->name_hint;
return (*it).second;
}

Expand Down
8 changes: 4 additions & 4 deletions src/relay/ir/pretty_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ class PrettyPrinter :
val << "-malformed-ir";
return val;
}
std::string name = var->var->name_hint;
std::string name = var->name_hint;
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "t" + name;
}
Expand Down Expand Up @@ -493,7 +493,7 @@ class PrettyPrinter :
doc << "[";
std::vector<Doc> type_params;
for (const TypeVar& tv : fn->type_params) {
type_params.push_back(Doc(tv->var->name_hint));
type_params.push_back(Doc(tv->name_hint));
}
doc << PrintSep(type_params);
doc << "]";
Expand Down Expand Up @@ -701,11 +701,11 @@ class PrettyPrinter :
}

Doc VisitType_(const TypeVarNode* node) final {
return Doc(node->var->name_hint);
return Doc(node->name_hint);
}

Doc VisitType_(const GlobalTypeVarNode* node) final {
return Doc(node->var->name_hint);
return Doc(node->name_hint);
}

Doc VisitType_(const TypeCallNode* node) final {
Expand Down
16 changes: 8 additions & 8 deletions src/relay/ir/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)

TypeVar TypeVarNode::make(std::string name, Kind kind) {
ObjectPtr<TypeVarNode> n = make_object<TypeVarNode>();
n->var = tvm::Var(name);
n->name_hint = std::move(name);
n->kind = std::move(kind);
return TypeVar(n);
}
Expand All @@ -74,19 +74,19 @@ TVM_REGISTER_NODE_TYPE(TypeVarNode);

TVM_REGISTER_API("relay._make.TypeVar")
.set_body_typed<TypeVar(std::string, int)>([](std::string name, int kind) {
return TypeVarNode::make(name, static_cast<Kind>(kind));
});
return TypeVarNode::make(name, static_cast<Kind>(kind));
});

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<TypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const TypeVarNode*>(ref.get());
p->stream << "TypeVarNode(" << node->var->name_hint << ", "
<< node->kind << ")";
auto* node = static_cast<const TypeVarNode*>(ref.get());
p->stream << "TypeVarNode(" << node->name_hint << ", "
<< node->kind << ")";
});

GlobalTypeVar GlobalTypeVarNode::make(std::string name, Kind kind) {
ObjectPtr<GlobalTypeVarNode> n = make_object<GlobalTypeVarNode>();
n->var = tvm::Var(name);
n->name_hint = std::move(name);
n->kind = std::move(kind);
return GlobalTypeVar(n);
}
Expand All @@ -101,7 +101,7 @@ TVM_REGISTER_API("relay._make.GlobalTypeVar")
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<GlobalTypeVarNode>([](const ObjectRef& ref, IRPrinter* p) {
auto* node = static_cast<const GlobalTypeVarNode*>(ref.get());
p->stream << "GlobalTypeVarNode(" << node->var->name_hint << ", "
p->stream << "GlobalTypeVarNode(" << node->name_hint << ", "
<< node->kind << ")";
});

Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/de_duplicate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Expr DeDup(const Expr& e) {
public PatternMutator {
public:
TypeVar Fresh(const TypeVar& tv) {
TypeVar ret = TypeVarNode::make(tv->var->name_hint, tv->kind);
TypeVar ret = TypeVarNode::make(tv->name_hint, tv->kind);
type_rename_[tv] = ret;
return ret;
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/to_cps.cc
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ Function UnCPS(const Function& f) {
auto new_ret_type = Type(cont_type->arg_types[0]);
std::vector<TypeVar> new_type_params;
for (const auto& tp : f->type_params) {
new_type_params.push_back(TypeVarNode::make(tp->var->name_hint, tp->kind));
new_type_params.push_back(TypeVarNode::make(tp->name_hint, tp->kind));
}
auto answer_type = new_type_params.back();
new_type_params.pop_back();
Expand Down
7 changes: 4 additions & 3 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,15 +498,16 @@ def test_fused_ops():
tvm.testing.assert_allclose(result.asnumpy(), (data + 1) * 2)

def test_arange_with_dynamic_shape():
m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k')
x = relay.var('x', shape=(m.var, n.var, k.var), dtype='float32')
# m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k')
m, n, k = relay.Any(), relay.Any(), relay.Any()
x = relay.var('x', shape=(m, n, k), dtype='float32')
y0 = relay.shape_of(x)
y1 = relay.take(y0, relay.const(0, 'int32'))
y2 = relay.op.arange(y1, dtype="int32")
y3 = y2 + relay.const(1, dtype="int32")
data = np.random.rand(10, 5, 3).astype('float32')
mod = relay.module.Module()
mod["main"] = relay.Function([x], y3, type_params=[m, n, k])
mod["main"] = relay.Function([x], y3)
for kind in ["debug", "vm"]:
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
result = ex.evaluate()(data)
Expand Down