From 4f8c0f9c7f71926ade1b3beab494be26d7971572 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Thu, 6 Jan 2022 15:57:57 -0800 Subject: [PATCH] Add default to serialization --- include/tvm/ir/expr.h | 1 + include/tvm/relay/adt.h | 1 + include/tvm/relay/expr.h | 9 +++++++ include/tvm/relay/function.h | 1 + src/node/serialization.cc | 6 +++-- tests/python/relay/test_json_compact.py | 36 +++++++++++++++++++++++++ 6 files changed, 52 insertions(+), 2 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 8937bb7b1016..2cfc8467a1cb 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -221,6 +221,7 @@ class GlobalVarNode : public RelayExprNode { void VisitAttrs(AttrVisitor* v) { v->Visit("name_hint", &name_hint); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index b1d4d5975cb8..31dec2204146 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -299,6 +299,7 @@ class MatchNode : public ExprNode { v->Visit("data", &data); v->Visit("clauses", &clauses); v->Visit("complete", &complete); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 04dd9223719e..dcb7838a1b72 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -108,6 +108,7 @@ class TupleNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -196,6 +197,7 @@ class VarNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("vid", &vid); v->Visit("type_annotation", &type_annotation); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -319,6 +321,7 @@ class CallNode : public ExprNode { v->Visit("args", &args); v->Visit("attrs", &attrs); v->Visit("type_args", &type_args); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -425,6 +428,7 @@ class LetNode : public ExprNode { v->Visit("var", &var); v->Visit("value", &value); v->Visit("body", &body); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -516,6 +520,7 @@ class IfNode : public ExprNode { v->Visit("cond", &cond); v->Visit("true_branch", &true_branch); v->Visit("false_branch", &false_branch); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -589,6 +594,7 @@ class TupleGetItemNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("tuple_value", &tuple); v->Visit("index", &index); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -652,6 +658,7 @@ class RefCreateNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("value", &value); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -713,6 +720,7 @@ class RefReadNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("ref", &ref); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -776,6 +784,7 @@ class RefWriteNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("ref", &ref); v->Visit("value", &value); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index d9bf7acaa037..5869f878aa85 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -64,6 +64,7 @@ class FunctionNode : public BaseFuncNode { v->Visit("ret_type", &ret_type); v->Visit("type_params", &type_params); v->Visit("attrs", &attrs); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 09eb02e10bfa..8134c895e389 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -315,7 +315,8 @@ class FieldDependencyFinder : public AttrVisitor { std::string GetValue(const char* key) const { auto it = jnode_->attrs.find(key); if (it == jnode_->attrs.end()) { - LOG(FATAL) << "JSONReader: cannot find field " << key; + // If we encounter a field that hasn't been set, initialize it to null. + return "0"; } return it->second; } @@ -372,7 +373,8 @@ class JSONAttrSetter : public AttrVisitor { std::string GetValue(const char* key) const { auto it = jnode_->attrs.find(key); if (it == jnode_->attrs.end()) { - LOG(FATAL) << "JSONReader: cannot find field " << key; + // If we encounter a field that hasn't been set, initialize it to null. + return "0"; } return it->second; } diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index 65efc306a347..3ca738ac8c8e 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -206,6 +206,42 @@ def test_str_map(): assert bool(x["z"] == 2) +def test_default_fields(): + # Node with all fields set + nodes = [ + {"type_key": ""}, + { + "type_key": "relay.GlobalVar", + "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0", "virtual_device_": "0"}, + }, + ] + data = { + "root": 1, + "nodes": nodes, + "attrs": {"tvm_version": "0.6.0"}, + "b64ndarrays": [], + } + tvar = tvm.ir.load_json(json.dumps(data)) + assert isinstance(tvar, tvm.ir.GlobalVar) + # Construct node without virtual_device_ field + nodes = [ + {"type_key": ""}, + { + "type_key": "relay.GlobalVar", + "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0"}, + }, + ] + data = { + "root": 1, + "nodes": nodes, + "attrs": {"tvm_version": "0.6.0"}, + "b64ndarrays": [], + } + tvar_default = tvm.ir.load_json(json.dumps(data)) + assert isinstance(tvar_default, tvm.ir.GlobalVar) + assert not tvar_default.virtual_device_ + + if __name__ == "__main__": test_op() test_type_var()