Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ class GlobalVarNode : public RelayExprNode {

void VisitAttrs(AttrVisitor* v) {
v->Visit("name_hint", &name_hint);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/adt.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ class MatchNode : public ExprNode {
v->Visit("data", &data);
v->Visit("clauses", &clauses);
v->Visit("complete", &complete);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class TupleNode : public ExprNode {

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("fields", &fields);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -196,6 +197,7 @@ class VarNode : public ExprNode {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("vid", &vid);
v->Visit("type_annotation", &type_annotation);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -319,6 +321,7 @@ class CallNode : public ExprNode {
v->Visit("args", &args);
v->Visit("attrs", &attrs);
v->Visit("type_args", &type_args);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -425,6 +428,7 @@ class LetNode : public ExprNode {
v->Visit("var", &var);
v->Visit("value", &value);
v->Visit("body", &body);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -516,6 +520,7 @@ class IfNode : public ExprNode {
v->Visit("cond", &cond);
v->Visit("true_branch", &true_branch);
v->Visit("false_branch", &false_branch);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -589,6 +594,7 @@ class TupleGetItemNode : public ExprNode {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("tuple_value", &tuple);
v->Visit("index", &index);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -652,6 +658,7 @@ class RefCreateNode : public ExprNode {

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("value", &value);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -713,6 +720,7 @@ class RefReadNode : public ExprNode {

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("ref", &ref);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down Expand Up @@ -776,6 +784,7 @@ class RefWriteNode : public ExprNode {
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("ref", &ref);
v->Visit("value", &value);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class FunctionNode : public BaseFuncNode {
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
v->Visit("attrs", &attrs);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
Expand Down
6 changes: 4 additions & 2 deletions src/node/serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,8 @@ class FieldDependencyFinder : public AttrVisitor {
std::string GetValue(const char* key) const {
auto it = jnode_->attrs.find(key);
if (it == jnode_->attrs.end()) {
LOG(FATAL) << "JSONReader: cannot find field " << key;
// If we encounter a field that hasn't been set, initialize it to null.
return "0";
}
return it->second;
}
Expand Down Expand Up @@ -372,7 +373,8 @@ class JSONAttrSetter : public AttrVisitor {
std::string GetValue(const char* key) const {
auto it = jnode_->attrs.find(key);
if (it == jnode_->attrs.end()) {
LOG(FATAL) << "JSONReader: cannot find field " << key;
// If we encounter a field that hasn't been set, initialize it to null.
return "0";
}
return it->second;
}
Expand Down
36 changes: 36 additions & 0 deletions tests/python/relay/test_json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,42 @@ def test_str_map():
assert bool(x["z"] == 2)


def test_default_fields():
# Node with all fields set
nodes = [
{"type_key": ""},
{
"type_key": "relay.GlobalVar",
"attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0", "virtual_device_": "0"},
},
]
data = {
"root": 1,
"nodes": nodes,
"attrs": {"tvm_version": "0.6.0"},
"b64ndarrays": [],
}
tvar = tvm.ir.load_json(json.dumps(data))
assert isinstance(tvar, tvm.ir.GlobalVar)
# Construct node without virtual_device_ field
nodes = [
{"type_key": ""},
{
"type_key": "relay.GlobalVar",
"attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0"},
},
]
data = {
"root": 1,
"nodes": nodes,
"attrs": {"tvm_version": "0.6.0"},
"b64ndarrays": [],
}
tvar_default = tvm.ir.load_json(json.dumps(data))
assert isinstance(tvar_default, tvm.ir.GlobalVar)
assert not tvar_default.virtual_device_


if __name__ == "__main__":
test_op()
test_type_var()
Expand Down