From 4f8c0f9c7f71926ade1b3beab494be26d7971572 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Thu, 6 Jan 2022 15:57:57 -0800 Subject: [PATCH 01/16] Add default to serialization --- include/tvm/ir/expr.h | 1 + include/tvm/relay/adt.h | 1 + include/tvm/relay/expr.h | 9 +++++++ include/tvm/relay/function.h | 1 + src/node/serialization.cc | 6 +++-- tests/python/relay/test_json_compact.py | 36 +++++++++++++++++++++++++ 6 files changed, 52 insertions(+), 2 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 8937bb7b1016..2cfc8467a1cb 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -221,6 +221,7 @@ class GlobalVarNode : public RelayExprNode { void VisitAttrs(AttrVisitor* v) { v->Visit("name_hint", &name_hint); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index b1d4d5975cb8..31dec2204146 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -299,6 +299,7 @@ class MatchNode : public ExprNode { v->Visit("data", &data); v->Visit("clauses", &clauses); v->Visit("complete", &complete); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 04dd9223719e..dcb7838a1b72 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -108,6 +108,7 @@ class TupleNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("fields", &fields); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -196,6 +197,7 @@ class VarNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("vid", &vid); v->Visit("type_annotation", &type_annotation); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -319,6 +321,7 @@ class CallNode : public ExprNode { v->Visit("args", &args); v->Visit("attrs", &attrs); v->Visit("type_args", &type_args); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -425,6 +428,7 @@ class LetNode : public ExprNode { v->Visit("var", &var); v->Visit("value", &value); v->Visit("body", &body); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -516,6 +520,7 @@ class IfNode : public ExprNode { v->Visit("cond", &cond); v->Visit("true_branch", &true_branch); v->Visit("false_branch", &false_branch); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -589,6 +594,7 @@ class TupleGetItemNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("tuple_value", &tuple); v->Visit("index", &index); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -652,6 +658,7 @@ class RefCreateNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("value", &value); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -713,6 +720,7 @@ class RefReadNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("ref", &ref); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } @@ -776,6 +784,7 @@ class RefWriteNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("ref", &ref); v->Visit("value", &value); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index d9bf7acaa037..5869f878aa85 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -64,6 +64,7 @@ class FunctionNode : public BaseFuncNode { v->Visit("ret_type", &ret_type); v->Visit("type_params", &type_params); v->Visit("attrs", &attrs); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 09eb02e10bfa..8134c895e389 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -315,7 +315,8 @@ class FieldDependencyFinder : public AttrVisitor { std::string GetValue(const char* key) const { auto it = jnode_->attrs.find(key); if (it == jnode_->attrs.end()) { - LOG(FATAL) << "JSONReader: cannot find field " << key; + // If we encounter a field that hasn't been set, initialize it to null. + return "0"; } return it->second; } @@ -372,7 +373,8 @@ class JSONAttrSetter : public AttrVisitor { std::string GetValue(const char* key) const { auto it = jnode_->attrs.find(key); if (it == jnode_->attrs.end()) { - LOG(FATAL) << "JSONReader: cannot find field " << key; + // If we encounter a field that hasn't been set, initialize it to null. + return "0"; } return it->second; } diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index 65efc306a347..3ca738ac8c8e 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -206,6 +206,42 @@ def test_str_map(): assert bool(x["z"] == 2) +def test_default_fields(): + # Node with all fields set + nodes = [ + {"type_key": ""}, + { + "type_key": "relay.GlobalVar", + "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0", "virtual_device_": "0"}, + }, + ] + data = { + "root": 1, + "nodes": nodes, + "attrs": {"tvm_version": "0.6.0"}, + "b64ndarrays": [], + } + tvar = tvm.ir.load_json(json.dumps(data)) + assert isinstance(tvar, tvm.ir.GlobalVar) + # Construct node without virtual_device_ field + nodes = [ + {"type_key": ""}, + { + "type_key": "relay.GlobalVar", + "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0"}, + }, + ] + data = { + "root": 1, + "nodes": nodes, + "attrs": {"tvm_version": "0.6.0"}, + "b64ndarrays": [], + } + tvar_default = tvm.ir.load_json(json.dumps(data)) + assert isinstance(tvar_default, tvm.ir.GlobalVar) + assert not tvar_default.virtual_device_ + + if __name__ == "__main__": test_op() test_type_var() From 993100abdc8ba06de7ab02780afcd351b9655ff9 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Fri, 7 Jan 2022 13:40:26 -0800 Subject: [PATCH 02/16] revert changes in serialization.cc --- src/node/serialization.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 8134c895e389..09eb02e10bfa 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -315,8 +315,7 @@ class FieldDependencyFinder : public AttrVisitor { std::string GetValue(const char* key) const { auto it = jnode_->attrs.find(key); if (it == jnode_->attrs.end()) { - // If we encounter a field that hasn't been set, initialize it to null. - return "0"; + LOG(FATAL) << "JSONReader: cannot find field " << key; } return it->second; } @@ -373,8 +372,7 @@ class JSONAttrSetter : public AttrVisitor { std::string GetValue(const char* key) const { auto it = jnode_->attrs.find(key); if (it == jnode_->attrs.end()) { - // If we encounter a field that hasn't been set, initialize it to null. - return "0"; + LOG(FATAL) << "JSONReader: cannot find field " << key; } return it->second; } From 5edb222a3e125e2525ef800d51546262326c69b7 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Fri, 7 Jan 2022 14:19:58 -0800 Subject: [PATCH 03/16] update 0.6 converter --- python/tvm/ir/json_compact.py | 34 +++++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index a22d7d3ce108..6d2d000f4bcc 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -56,8 +56,10 @@ def _updater(data): return _updater +def create_updater_08_to_09(): + pass -def create_updater_06_to_07(): +def create_updater_06_to_09(): """Create an update to upgrade json from v0.6 to v0.7 Returns @@ -91,6 +93,15 @@ def _convert(item, _): return _convert + def _initialize_virtual_device(item, _): + print(item) + item["attrs"]["virtual_device_"] = "0" + return item + + def _initialize_module_attributes(item, _): + item["attrs"]["attrs"] = "0" + return item + def _update_global_key(item, _): if "global_key" in item: item["repr_str"] = item["global_key"] @@ -128,17 +139,28 @@ def _convert(item, nodes): "relay.TypeRelation": _rename("TypeRelation"), "relay.TypeCall": _rename("TypeCall"), "relay.Constructor": [_update_from_std_str("name_hint")], - "relay.Module": _rename("IRModule"), + "relay.Module": [_rename("IRModule"), _initialize_module_attributes], "relay.SourceName": _rename("SourceName"), "relay.Span": _rename("Span"), - "relay.GlobalVar": [_rename("GlobalVar"), _update_from_std_str("name_hint")], - "GlobalVar": _update_from_std_str("name_hint"), + "relay.GlobalVar": [_rename("GlobalVar"), _update_from_std_str("name_hint"), _initialize_virtual_device], + "GlobalVar": [_update_from_std_str("name_hint"), _initialize_virtual_device], + "relay.Var": _initialize_virtual_device, + "relay.Function": _initialize_virtual_device, + "relay.Tuple": _initialize_virtual_device, + "relay.Call": _initialize_virtual_device, + "relay.Let": _initialize_virtual_device, + "relay.If": _initialize_virtual_device, + "relay.TupleGetItem": _initialize_virtual_device, + "relay.RefCreate": _initialize_virtual_device, + "relay.RefRead": _initialize_virtual_device, + "relay.RefWrite": _initialize_virtual_device, "relay.Pass": _rename("transform.Pass"), "relay.PassInfo": _rename("transform.PassInfo"), "relay.PassContext": _rename("transform.PassContext"), "relay.ModulePass": _rename("transform.ModulePass"), "relay.Sequential": _rename("transform.Sequential"), "StrMap": _rename("Map"), + # TIR "Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")], "SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")], @@ -207,7 +229,9 @@ def upgrade_json(json_str): data = json.loads(json_str) from_version = data["attrs"]["tvm_version"] if from_version.startswith("0.6"): - data = create_updater_06_to_07()(data) + data = create_updater_06_to_09()(data) + elif from_version.startswith("0.8"): + data = create_updater_08_to_09()(data) else: raise ValueError("Cannot update from version %s" % from_version) return json.dumps(data, indent=2) From 32f75cb060a2b8443a03b32b43485ef8877b36be Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Fri, 7 Jan 2022 16:55:36 -0800 Subject: [PATCH 04/16] json updater working, except for cycles --- python/tvm/ir/base.py | 6 +- python/tvm/ir/json_compact.py | 164 +++++++++++++++--------- src/node/serialization.cc | 24 +++- tests/python/relay/test_json_compact.py | 68 +++++++--- 4 files changed, 177 insertions(+), 85 deletions(-) diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index 00514b472d67..001f4d764c4e 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -135,9 +135,13 @@ def load_json(json_str): """ try: - return tvm.runtime._ffi_node_api.LoadJSON(json_str) + loaded = tvm.runtime._ffi_node_api.LoadJSON(json_str) + print("LOADED COMPLETE") + return loaded except tvm.error.TVMError: + print("Upgrading Json") json_str = json_compact.upgrade_json(json_str) + print("Loading again") return tvm.runtime._ffi_node_api.LoadJSON(json_str) diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index 6d2d000f4bcc..0ef6ced23b87 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -56,10 +56,57 @@ def _updater(data): return _updater + def create_updater_08_to_09(): - pass + """ + Create an update to upgrade json from v0.8 to v0.9 + + Returns + ------- + fupdater : function + The updater function + """ + + def _initialize_virtual_device(item, _): + print("Item before: ", item) + item["attrs"]["virtual_device_"] = "0" + print("Item after: ", item) + return item + + node_map = { + # Base IR + "GlobalVar": _initialize_virtual_device, + "relay.Var": _initialize_virtual_device, + "relay.Function": _initialize_virtual_device, + "relay.Tuple": _initialize_virtual_device, + "relay.Call": _initialize_virtual_device, + "relay.Let": _initialize_virtual_device, + "relay.If": _initialize_virtual_device, + "relay.TupleGetItem": _initialize_virtual_device, + "relay.RefCreate": _initialize_virtual_device, + "relay.RefRead": _initialize_virtual_device, + "relay.RefWrite": _initialize_virtual_device, + "relay.Match": _initialize_virtual_device, + } + + return create_updater(node_map, "0.8", "0.9") -def create_updater_06_to_09(): + +def create_updater_07_to_08(): + """Create an update to upgrade json from v0.7 to v0.8""" + + def _initialize_module_attributes(item, _): + assert item["type_key"] == "IRModule", "Only initialize the attributes for IRModules" + print("Module before: ", item) + item["attrs"]["attrs"] = "0" + print("Module after:", item) + return item + + node_map = {"IRModule": _initialize_module_attributes} + return create_updater(node_map, "0.7", "0.8") + + +def create_updater_06_to_07(): """Create an update to upgrade json from v0.6 to v0.7 Returns @@ -93,15 +140,6 @@ def _convert(item, _): return _convert - def _initialize_virtual_device(item, _): - print(item) - item["attrs"]["virtual_device_"] = "0" - return item - - def _initialize_module_attributes(item, _): - item["attrs"]["attrs"] = "0" - return item - def _update_global_key(item, _): if "global_key" in item: item["repr_str"] = item["global_key"] @@ -138,70 +176,59 @@ def _convert(item, nodes): "relay.IncompleteType": _rename("IncompleteType"), "relay.TypeRelation": _rename("TypeRelation"), "relay.TypeCall": _rename("TypeCall"), - "relay.Constructor": [_update_from_std_str("name_hint")], - "relay.Module": [_rename("IRModule"), _initialize_module_attributes], + "relay.Constructor": _update_from_std_str("name_hint"), + "relay.Module": _rename("IRModule"), "relay.SourceName": _rename("SourceName"), "relay.Span": _rename("Span"), - "relay.GlobalVar": [_rename("GlobalVar"), _update_from_std_str("name_hint"), _initialize_virtual_device], - "GlobalVar": [_update_from_std_str("name_hint"), _initialize_virtual_device], - "relay.Var": _initialize_virtual_device, - "relay.Function": _initialize_virtual_device, - "relay.Tuple": _initialize_virtual_device, - "relay.Call": _initialize_virtual_device, - "relay.Let": _initialize_virtual_device, - "relay.If": _initialize_virtual_device, - "relay.TupleGetItem": _initialize_virtual_device, - "relay.RefCreate": _initialize_virtual_device, - "relay.RefRead": _initialize_virtual_device, - "relay.RefWrite": _initialize_virtual_device, + "relay.GlobalVar": [_rename("GlobalVar"), _update_from_std_str("name_hint")], + "GlobalVar": _update_from_std_str("name_hint"), "relay.Pass": _rename("transform.Pass"), "relay.PassInfo": _rename("transform.PassInfo"), "relay.PassContext": _rename("transform.PassContext"), "relay.ModulePass": _rename("transform.ModulePass"), "relay.Sequential": _rename("transform.Sequential"), "StrMap": _rename("Map"), - # TIR "Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")], "SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")], "StringImm": [_rename("tir.StringImm"), _update_from_std_str("value")], - "Cast": [_rename("tir.Cast")], - "Add": [_rename("tir.Add")], - "Sub": [_rename("tir.Sub")], - "Mul": [_rename("tir.Mul")], - "Div": [_rename("tir.Div")], - "Mod": [_rename("tir.Mod")], - "FloorDiv": [_rename("tir.FloorDiv")], - "FloorMod": [_rename("tir.FloorMod")], - "Min": [_rename("tir.Min")], - "Max": [_rename("tir.Max")], - "EQ": [_rename("tir.EQ")], - "NE": [_rename("tir.NE")], - "LT": [_rename("tir.LT")], - "LE": [_rename("tir.LE")], - "GT": [_rename("tir.GT")], - "GE": [_rename("tir.GE")], - "And": [_rename("tir.And")], - "Or": [_rename("tir.Or")], - "Not": [_rename("tir.Not")], - "Select": [_rename("tir.Select")], - "Load": [_rename("tir.Load")], - "BufferLoad": [_rename("tir.BufferLoad")], - "Ramp": [_rename("tir.Ramp")], - "Broadcast": [_rename("tir.Broadcast")], - "Shuffle": [_rename("tir.Shuffle")], + "Cast": _rename("tir.Cast"), + "Add": _rename("tir.Add"), + "Sub": _rename("tir.Sub"), + "Mul": _rename("tir.Mul"), + "Div": _rename("tir.Div"), + "Mod": _rename("tir.Mod"), + "FloorDiv": _rename("tir.FloorDiv"), + "FloorMod": _rename("tir.FloorMod"), + "Min": _rename("tir.Min"), + "Max": _rename("tir.Max"), + "EQ": _rename("tir.EQ"), + "NE": _rename("tir.NE"), + "LT": _rename("tir.LT"), + "LE": _rename("tir.LE"), + "GT": _rename("tir.GT"), + "GE": _rename("tir.GE"), + "And": _rename("tir.And"), + "Or": _rename("tir.Or"), + "Not": _rename("tir.Not"), + "Select": _rename("tir.Select"), + "Load": _rename("tir.Load"), + "BufferLoad": _rename("tir.BufferLoad"), + "Ramp": _rename("tir.Ramp"), + "Broadcast": _rename("tir.Broadcast"), + "Shuffle": _rename("tir.Shuffle"), "Call": [_rename("tir.Call"), _update_from_std_str("name")], - "Let": [_rename("tir.Let")], - "Any": [_rename("tir.Any")], - "LetStmt": [_rename("tir.LetStmt")], - "AssertStmt": [_rename("tir.AssertStmt")], - "Store": [_rename("tir.Store")], - "BufferStore": [_rename("tir.BufferStore")], - "BufferRealize": [_rename("tir.BufferRealize")], - "Allocate": [_rename("tir.Allocate")], - "IfThenElse": [_rename("tir.IfThenElse")], - "Evaluate": [_rename("tir.Evaluate")], - "Prefetch": [_rename("tir.Prefetch")], + "Let": _rename("tir.Let"), + "Any": _rename("tir.Any"), + "LetStmt": _rename("tir.LetStmt"), + "AssertStmt": _rename("tir.AssertStmt"), + "Store": _rename("tir.Store"), + "BufferStore": _rename("tir.BufferStore"), + "BufferRealize": _rename("tir.BufferRealize"), + "Allocate": _rename("tir.Allocate"), + "IfThenElse": _rename("tir.IfThenElse"), + "Evaluate": _rename("tir.Evaluate"), + "Prefetch": _rename("tir.Prefetch"), "AttrStmt": [_rename("tir.AttrStmt"), _update_from_std_str("attr_key")], "Layout": [_rename("tir.Layout"), _update_from_std_str("name")], "Buffer": [ @@ -227,10 +254,21 @@ def upgrade_json(json_str): The updated version. """ data = json.loads(json_str) + print("Completed loading") from_version = data["attrs"]["tvm_version"] + if from_version.startswith("0.6"): - data = create_updater_06_to_09()(data) + print("From 0.6") + data = create_updater_08_to_09()(create_updater_07_to_08()(create_updater_06_to_07()(data))) + elif from_version.startswith("0.7"): + print("From 0.7") + data1 = create_updater_07_to_08()(data) + print("First updater done") + data2 = create_updater_08_to_09()(data1) + print("2nd updater done") + data = data2 elif from_version.startswith("0.8"): + print("From 0.8") data = create_updater_08_to_09()(data) else: raise ValueError("Cannot update from version %s" % from_version) diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 09eb02e10bfa..f909efaec8c8 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -313,14 +313,24 @@ class FieldDependencyFinder : public AttrVisitor { ReflectionVTable* reflection_ = ReflectionVTable::Global(); std::string GetValue(const char* key) const { + std::cout << "Dependency finder" << std::endl; + std::cout << "Key: " << key << std::endl; + std::cout << "All keys: ["; + + for (auto kv : jnode_->attrs) { + std::cout << kv.first << ", "; + } + std::cout << "]" << std::endl; + auto it = jnode_->attrs.find(key); if (it == jnode_->attrs.end()) { - LOG(FATAL) << "JSONReader: cannot find field " << key; + LOG(FATAL) << "JSONReader (DependencyFinder): cannot find field " << key; } return it->second; } template void ParseValue(const char* key, T* value) const { + std::cout << "ParseValue for " << key << std::endl; std::istringstream is(GetValue(key)); is >> *value; if (is.fail()) { @@ -337,6 +347,7 @@ class FieldDependencyFinder : public AttrVisitor { void Visit(const char* key, DataType* value) final {} void Visit(const char* key, runtime::NDArray* value) final {} void Visit(const char* key, ObjectRef* value) final { + std::cout << "Object: " << PrettyPrint(*value) << std::endl; size_t index; ParseValue(key, &index); jnode_->fields.push_back(index); @@ -370,9 +381,16 @@ class JSONAttrSetter : public AttrVisitor { ReflectionVTable* reflection_ = ReflectionVTable::Global(); std::string GetValue(const char* key) const { + std::cout << "Key: " << key << std::endl; + std::cout << "All keys: ["; + + for (auto kv : jnode_->attrs) { + std::cout << kv.first << ", "; + } + std::cout << "]" << std::endl; auto it = jnode_->attrs.find(key); if (it == jnode_->attrs.end()) { - LOG(FATAL) << "JSONReader: cannot find field " << key; + LOG(FATAL) << "JSONReader (AttrSetter): cannot find field " << key; } return it->second; } @@ -557,7 +575,7 @@ struct JSONGraph { } } } - ICHECK_EQ(topo_order.size(), n_nodes) << "Cyclic reference detected in JSON file"; + // ICHECK_EQ(topo_order.size(), n_nodes) << "Cyclic reference detected in JSON file"; std::reverse(std::begin(topo_order), std::end(topo_order)); return topo_order; } diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index 3ca738ac8c8e..2d8eab6e3b3b 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -21,6 +21,9 @@ import json +# 0.6 BACKWARDS COMPATIBILITY TESTS + + def test_type_var(): # type var in 0.6 nodes = [ @@ -206,41 +209,70 @@ def test_str_map(): assert bool(x["z"] == 2) -def test_default_fields(): - # Node with all fields set +# 0.7 BACKWARDS COMPATIBILITY TESTS + + +def test_irmodule_attributes(): nodes = [ - {"type_key": ""}, { - "type_key": "relay.GlobalVar", - "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0", "virtual_device_": "0"}, - }, + "type_key": "IRModule", + "attrs": { + "functions": "0", + "global_type_var_map_": "0", + "global_var_map_": "0", + "source_map": "0", + "type_definitions": "0", + }, + } ] data = { "root": 1, "nodes": nodes, - "attrs": {"tvm_version": "0.6.0"}, + "attrs": {"tvm_version": "0.7.0"}, "b64ndarrays": [], } - tvar = tvm.ir.load_json(json.dumps(data)) - assert isinstance(tvar, tvm.ir.GlobalVar) - # Construct node without virtual_device_ field + mod = tvm.ir.load_json(json.dumps(data)) + assert isinstance(mod, tvm.ir.IRModule) + # IRModule attributes should defualt to null + assert not mod.attrs + + +# 0.8 BACKWARDS COMPATIBILITY TESTS + +# Does this break with functions? Yes. Seems bad. Probably should remove json dep checker? +def test_func_cycle(): nodes = [ - {"type_key": ""}, { - "type_key": "relay.GlobalVar", - "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0"}, - }, + "type_key": "relay.Function", + "attrs": { + "_checked_type_": "0", + "attrs": "0", + "body": "0", + "params": "0", + "ret_type": "0", + "span": "0", + "type_params": "0", + }, + } ] data = { "root": 1, "nodes": nodes, - "attrs": {"tvm_version": "0.6.0"}, + "attrs": {"tvm_version": "0.8.0"}, "b64ndarrays": [], } - tvar_default = tvm.ir.load_json(json.dumps(data)) - assert isinstance(tvar_default, tvm.ir.GlobalVar) - assert not tvar_default.virtual_device_ + dump = json.dumps(data) + print("Done dumping") + func = tvm.ir.load_json(dump) + assert isinstance(func, relay.Function) + assert not func.virtual_device_ + + +# add module attributes and virtual device test + +# BACKWARD COMPAT WITH 0.8 TESTS +# add test module attrs and test virtual device if __name__ == "__main__": test_op() From 72f2b8524c2f98d5f5b4e526649b436363c935ed Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Fri, 7 Jan 2022 16:58:26 -0800 Subject: [PATCH 05/16] clean up code --- python/tvm/ir/base.py | 6 +----- python/tvm/ir/json_compact.py | 14 +------------- src/node/serialization.cc | 24 +++--------------------- tests/python/relay/test_json_compact.py | 1 - 4 files changed, 5 insertions(+), 40 deletions(-) diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index 001f4d764c4e..00514b472d67 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -135,13 +135,9 @@ def load_json(json_str): """ try: - loaded = tvm.runtime._ffi_node_api.LoadJSON(json_str) - print("LOADED COMPLETE") - return loaded + return tvm.runtime._ffi_node_api.LoadJSON(json_str) except tvm.error.TVMError: - print("Upgrading Json") json_str = json_compact.upgrade_json(json_str) - print("Loading again") return tvm.runtime._ffi_node_api.LoadJSON(json_str) diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index 0ef6ced23b87..2be329057dfa 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -68,9 +68,7 @@ def create_updater_08_to_09(): """ def _initialize_virtual_device(item, _): - print("Item before: ", item) item["attrs"]["virtual_device_"] = "0" - print("Item after: ", item) return item node_map = { @@ -97,9 +95,7 @@ def create_updater_07_to_08(): def _initialize_module_attributes(item, _): assert item["type_key"] == "IRModule", "Only initialize the attributes for IRModules" - print("Module before: ", item) item["attrs"]["attrs"] = "0" - print("Module after:", item) return item node_map = {"IRModule": _initialize_module_attributes} @@ -254,21 +250,13 @@ def upgrade_json(json_str): The updated version. """ data = json.loads(json_str) - print("Completed loading") from_version = data["attrs"]["tvm_version"] if from_version.startswith("0.6"): - print("From 0.6") data = create_updater_08_to_09()(create_updater_07_to_08()(create_updater_06_to_07()(data))) elif from_version.startswith("0.7"): - print("From 0.7") - data1 = create_updater_07_to_08()(data) - print("First updater done") - data2 = create_updater_08_to_09()(data1) - print("2nd updater done") - data = data2 + data = create_updater_08_to_09()(create_updater_07_to_08()(data)) elif from_version.startswith("0.8"): - print("From 0.8") data = create_updater_08_to_09()(data) else: raise ValueError("Cannot update from version %s" % from_version) diff --git a/src/node/serialization.cc b/src/node/serialization.cc index f909efaec8c8..09eb02e10bfa 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -313,24 +313,14 @@ class FieldDependencyFinder : public AttrVisitor { ReflectionVTable* reflection_ = ReflectionVTable::Global(); std::string GetValue(const char* key) const { - std::cout << "Dependency finder" << std::endl; - std::cout << "Key: " << key << std::endl; - std::cout << "All keys: ["; - - for (auto kv : jnode_->attrs) { - std::cout << kv.first << ", "; - } - std::cout << "]" << std::endl; - auto it = jnode_->attrs.find(key); if (it == jnode_->attrs.end()) { - LOG(FATAL) << "JSONReader (DependencyFinder): cannot find field " << key; + LOG(FATAL) << "JSONReader: cannot find field " << key; } return it->second; } template void ParseValue(const char* key, T* value) const { - std::cout << "ParseValue for " << key << std::endl; std::istringstream is(GetValue(key)); is >> *value; if (is.fail()) { @@ -347,7 +337,6 @@ class FieldDependencyFinder : public AttrVisitor { void Visit(const char* key, DataType* value) final {} void Visit(const char* key, runtime::NDArray* value) final {} void Visit(const char* key, ObjectRef* value) final { - std::cout << "Object: " << PrettyPrint(*value) << std::endl; size_t index; ParseValue(key, &index); jnode_->fields.push_back(index); @@ -381,16 +370,9 @@ class JSONAttrSetter : public AttrVisitor { ReflectionVTable* reflection_ = ReflectionVTable::Global(); std::string GetValue(const char* key) const { - std::cout << "Key: " << key << std::endl; - std::cout << "All keys: ["; - - for (auto kv : jnode_->attrs) { - std::cout << kv.first << ", "; - } - std::cout << "]" << std::endl; auto it = jnode_->attrs.find(key); if (it == jnode_->attrs.end()) { - LOG(FATAL) << "JSONReader (AttrSetter): cannot find field " << key; + LOG(FATAL) << "JSONReader: cannot find field " << key; } return it->second; } @@ -575,7 +557,7 @@ struct JSONGraph { } } } - // ICHECK_EQ(topo_order.size(), n_nodes) << "Cyclic reference detected in JSON file"; + ICHECK_EQ(topo_order.size(), n_nodes) << "Cyclic reference detected in JSON file"; std::reverse(std::begin(topo_order), std::end(topo_order)); return topo_order; } diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index 2d8eab6e3b3b..5358e61a3b88 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -262,7 +262,6 @@ def test_func_cycle(): "b64ndarrays": [], } dump = json.dumps(data) - print("Done dumping") func = tvm.ir.load_json(dump) assert isinstance(func, relay.Function) assert not func.virtual_device_ From 435ad5f8ff2bd80d3827ea04245092f96b2590d0 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Fri, 7 Jan 2022 17:33:02 -0800 Subject: [PATCH 06/16] Fix tests --- python/tvm/ir/json_compact.py | 6 ++++-- tests/python/relay/test_json_compact.py | 9 +++++---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index 2be329057dfa..ec8cd6c0a4b2 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -68,7 +68,8 @@ def create_updater_08_to_09(): """ def _initialize_virtual_device(item, _): - item["attrs"]["virtual_device_"] = "0" + if ("virtual_device_" not in item["attrs"].keys()): + item["attrs"]["virtual_device_"] = "0" return item node_map = { @@ -95,7 +96,8 @@ def create_updater_07_to_08(): def _initialize_module_attributes(item, _): assert item["type_key"] == "IRModule", "Only initialize the attributes for IRModules" - item["attrs"]["attrs"] = "0" + if "attrs" not in item["attrs"].keys(): + item["attrs"]["attrs"] = "0" return item node_map = {"IRModule": _initialize_module_attributes} diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index 5358e61a3b88..528ee271091d 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -214,6 +214,7 @@ def test_str_map(): def test_irmodule_attributes(): nodes = [ + {"type_key": ""}, { "type_key": "IRModule", "attrs": { @@ -223,7 +224,7 @@ def test_irmodule_attributes(): "source_map": "0", "type_definitions": "0", }, - } + }, ] data = { "root": 1, @@ -239,9 +240,9 @@ def test_irmodule_attributes(): # 0.8 BACKWARDS COMPATIBILITY TESTS -# Does this break with functions? Yes. Seems bad. Probably should remove json dep checker? -def test_func_cycle(): +def test_virtual_device(): nodes = [ + {"type_key": ""}, { "type_key": "relay.Function", "attrs": { @@ -253,7 +254,7 @@ def test_func_cycle(): "span": "0", "type_params": "0", }, - } + }, ] data = { "root": 1, From c67e4e8104add537240ac9e2287c6e149b9b3d7f Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Fri, 7 Jan 2022 17:37:13 -0800 Subject: [PATCH 07/16] formatting --- python/tvm/ir/json_compact.py | 2 +- tests/python/relay/test_json_compact.py | 7 +------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index ec8cd6c0a4b2..4f49e4a641eb 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -68,7 +68,7 @@ def create_updater_08_to_09(): """ def _initialize_virtual_device(item, _): - if ("virtual_device_" not in item["attrs"].keys()): + if "virtual_device_" not in item["attrs"].keys(): item["attrs"]["virtual_device_"] = "0" return item diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index 528ee271091d..b4418b043c8c 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -240,6 +240,7 @@ def test_irmodule_attributes(): # 0.8 BACKWARDS COMPATIBILITY TESTS + def test_virtual_device(): nodes = [ {"type_key": ""}, @@ -268,12 +269,6 @@ def test_virtual_device(): assert not func.virtual_device_ -# add module attributes and virtual device test - -# BACKWARD COMPAT WITH 0.8 TESTS - -# add test module attrs and test virtual device - if __name__ == "__main__": test_op() test_type_var() From 2bd0a01c883df06caa1b9d970e393b45d04f39ed Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Fri, 7 Jan 2022 17:40:20 -0800 Subject: [PATCH 08/16] format : --- tests/python/relay/test_json_compact.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index b4418b043c8c..5a7084eb53ab 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -263,8 +263,7 @@ def test_virtual_device(): "attrs": {"tvm_version": "0.8.0"}, "b64ndarrays": [], } - dump = json.dumps(data) - func = tvm.ir.load_json(dump) + func = tvm.ir.load_json(json.dumps(data)) assert isinstance(func, relay.Function) assert not func.virtual_device_ From c851c4d78a10da6c959a3f39d4f836fce92d1d13 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Mon, 3 Jan 2022 16:59:59 -0800 Subject: [PATCH 09/16] Check that virtual id is unchanged in WithFields --- src/relay/ir/expr.cc | 23 ++++++++++++++++------- src/relay/ir/function.cc | 4 +++- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index f8cb4f0728e0..10090acb880d 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -100,7 +100,8 @@ Tuple WithFields(Tuple tuple, Optional> opt_fields, all_fields_unchanged = false; } - all_fields_unchanged = all_fields_unchanged && span.same_as(tuple->span); + all_fields_unchanged = all_fields_unchanged && virtual_device.same_as(tuple->virtual_device_) && + span.same_as(tuple->span); if (!all_fields_unchanged) { TupleNode* cow_tuple_node = tuple.CopyOnWrite(); cow_tuple_node->fields = fields; @@ -139,7 +140,7 @@ Var WithFields(Var var, Optional opt_vid, Optional opt_type_annotation Span span = opt_span.value_or(var->span); bool unchanged = vid.same_as(var->vid) && type_annotation.same_as(var->type_annotation) && - span.same_as(var->span); + virtual_device.same_as(var->virtual_device_) && span.same_as(var->span); if (!unchanged) { VarNode* cow_var_node = var.CopyOnWrite(); @@ -188,7 +189,8 @@ Call WithFields(Call call, Optional opt_op, Optional> opt_args VirtualDevice virtual_device = opt_virtual_device.value_or(call->virtual_device()); Span span = opt_span.value_or(call->span); - bool unchanged = op.same_as(call->op) && attrs.same_as(call->attrs) && span.same_as(call->span); + bool unchanged = op.same_as(call->op) && attrs.same_as(call->attrs) && + virtual_device.same_as(call->virtual_device_) && span.same_as(call->span); // Check that the args are unchanged if (unchanged) { @@ -261,7 +263,7 @@ Let WithFields(Let let, Optional opt_var, Optional opt_value, Optiona Span span = opt_span.value_or(let->span); bool unchanged = var.same_as(let->var) && value.same_as(let->value) && body.same_as(let->body) && - span.same_as(let->span); + virtual_device.same_as(let->virtual_device_) && span.same_as(let->span); if (!unchanged) { LetNode* cow_let_node = let.CopyOnWrite(); @@ -305,7 +307,8 @@ If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branc Span span = opt_span.value_or(if_expr->span); bool unchanged = cond.same_as(if_expr->cond) && true_branch.same_as(if_expr->true_branch) && - false_branch.same_as(if_expr->false_branch) && span.same_as(if_expr->span); + false_branch.same_as(if_expr->false_branch) && + virtual_device.same_as(if_expr->virtual_device_) && span.same_as(if_expr->span); if (!unchanged) { IfNode* cow_if_node = if_expr.CopyOnWrite(); @@ -349,6 +352,7 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, Span span = opt_span.value_or(tuple_get_item->span); bool unchanged = tuple.same_as(tuple_get_item->tuple) && (index == tuple_get_item->index) && + virtual_device.same_as(tuple_get_item->virtual_device_) && span.same_as(tuple_get_item->span); if (!unchanged) { TupleGetItemNode* cow_tuple_get_item_node = tuple_get_item.CopyOnWrite(); @@ -385,7 +389,9 @@ RefCreate WithFields(RefCreate ref_create, Optional opt_value, VirtualDevice virtual_device = opt_virtual_device.value_or(ref_create->virtual_device()); Span span = opt_span.value_or(ref_create->span); - bool unchanged = value.same_as(ref_create->value) && span.same_as(ref_create->span); + bool unchanged = value.same_as(ref_create->value) && + virtual_device.same_as(ref_create->virtual_device_) && + span.same_as(ref_create->span); if (!unchanged) { RefCreateNode* cow_ref_create_node = ref_create.CopyOnWrite(); cow_ref_create_node->value = value; @@ -420,7 +426,9 @@ RefRead WithFields(RefRead ref_read, Optional opt_ref, VirtualDevice virtual_device = opt_virtual_device.value_or(ref_read->virtual_device()); Span span = opt_span.value_or(ref_read->span); - bool unchanged = ref.same_as(ref_read->ref) && span.same_as(ref_read->span); + bool unchanged = ref.same_as(ref_read->ref) && + virtual_device.same_as(ref_read->virtual_device_) && + span.same_as(ref_read->span); if (!unchanged) { RefReadNode* cow_ref_read_node = ref_read.CopyOnWrite(); cow_ref_read_node->ref = ref; @@ -456,6 +464,7 @@ RefWrite WithFields(RefWrite ref_write, Optional opt_ref, Optional o Span span = opt_span.value_or(ref_write->span); bool unchanged = ref.same_as(ref_write->ref) && value.same_as(ref_write->value) && + virtual_device.same_as(ref_write->virtual_device_) && span.same_as(ref_write->span); if (!unchanged) { RefWriteNode* cow_ref_write_node = ref_write.CopyOnWrite(); diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 43305402557a..0c54299a2c61 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -53,7 +53,9 @@ Function WithFields(Function function, Optional> opt_params, Optional Span span = opt_span.value_or(function->span); bool unchanged = body.same_as(function->body) && ret_type.same_as(function->ret_type) && - attrs.same_as(function->attrs) && span.same_as(function->span); + attrs.same_as(function->attrs) && + virtual_device.same_as(function->virtual_device_) && + span.same_as(function->span); // Check that all the type params are unchanged if (unchanged) { From 7009dfef8453997b53e4f43e2f78884ba8d457c8 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Tue, 4 Jan 2022 15:36:43 -0800 Subject: [PATCH 10/16] Set virtual_device_ to fully unconstrained in ctor --- include/tvm/ir/expr.h | 2 ++ src/relay/ir/expr.cc | 34 ++++++++++++++++++++-------------- src/relay/ir/function.cc | 3 ++- 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 2cfc8467a1cb..c7d0f58e3d9f 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -180,6 +180,8 @@ class RelayExprNode : public BaseExprNode { * the call to the function or closure is stored (instead of where the function itself is stored). * The VirtualDevice's Target field describes how the body of the function should be compiled. * + * Set to VirtualDevice::FullyUnconstrained by default. + * * \note Unfortunately, the type of virtual_device_ needs to be ObjectRef to avoid a circular * import. */ diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 10090acb880d..bab0ba415b05 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -28,10 +28,7 @@ namespace tvm { VirtualDevice RelayExprNode::virtual_device() const { - if (virtual_device_.defined()) { - return Downcast(this->virtual_device_); - } - return VirtualDevice::FullyUnconstrained(); + return Downcast(this->virtual_device_); } namespace relay { @@ -76,6 +73,7 @@ TensorType ConstantNode::tensor_type() const { Tuple::Tuple(tvm::Array fields, Span span) { ObjectPtr n = make_object(); n->fields = std::move(fields); + n->virtual_device_ = VirtualDevice::FullyUnconstrained(); n->span = std::move(span); data_ = std::move(n); } @@ -100,7 +98,7 @@ Tuple WithFields(Tuple tuple, Optional> opt_fields, all_fields_unchanged = false; } - all_fields_unchanged = all_fields_unchanged && virtual_device.same_as(tuple->virtual_device_) && + all_fields_unchanged = all_fields_unchanged && virtual_device.same_as(tuple->virtual_device()) && span.same_as(tuple->span); if (!all_fields_unchanged) { TupleNode* cow_tuple_node = tuple.CopyOnWrite(); @@ -121,6 +119,7 @@ Var::Var(Id vid, Type type_annotation, Span span) { ObjectPtr n = make_object(); n->vid = std::move(vid); n->type_annotation = std::move(type_annotation); + n->virtual_device_ = VirtualDevice::FullyUnconstrained(); n->span = std::move(span); data_ = std::move(n); } @@ -140,13 +139,13 @@ Var WithFields(Var var, Optional opt_vid, Optional opt_type_annotation Span span = opt_span.value_or(var->span); bool unchanged = vid.same_as(var->vid) && type_annotation.same_as(var->type_annotation) && - virtual_device.same_as(var->virtual_device_) && span.same_as(var->span); + virtual_device.same_as(var->virtual_device()) && span.same_as(var->span); if (!unchanged) { VarNode* cow_var_node = var.CopyOnWrite(); cow_var_node->vid = vid; cow_var_node->type_annotation = type_annotation; - cow_var_node->virtual_device_ = virtual_device; + cow_var_node->virtual_device_ = var->virtual_device(); cow_var_node->span = span; } return var; @@ -175,6 +174,7 @@ Call::Call(Expr op, Array args, Attrs attrs, Array type_args, Span s n->args = std::move(args); n->attrs = std::move(attrs); n->type_args = std::move(type_args); + n->virtual_device_ = VirtualDevice::FullyUnconstrained(); n->span = std::move(span); data_ = std::move(n); } @@ -190,7 +190,7 @@ Call WithFields(Call call, Optional opt_op, Optional> opt_args Span span = opt_span.value_or(call->span); bool unchanged = op.same_as(call->op) && attrs.same_as(call->attrs) && - virtual_device.same_as(call->virtual_device_) && span.same_as(call->span); + virtual_device.same_as(call->virtual_device()) && span.same_as(call->span); // Check that the args are unchanged if (unchanged) { @@ -250,6 +250,7 @@ Let::Let(Var var, Expr value, Expr body, Span span) { n->var = std::move(var); n->value = std::move(value); n->body = std::move(body); + n->virtual_device_ = VirtualDevice::FullyUnconstrained(); n->span = std::move(span); data_ = std::move(n); } @@ -263,7 +264,7 @@ Let WithFields(Let let, Optional opt_var, Optional opt_value, Optiona Span span = opt_span.value_or(let->span); bool unchanged = var.same_as(let->var) && value.same_as(let->value) && body.same_as(let->body) && - virtual_device.same_as(let->virtual_device_) && span.same_as(let->span); + virtual_device.same_as(let->virtual_device()) && span.same_as(let->span); if (!unchanged) { LetNode* cow_let_node = let.CopyOnWrite(); @@ -293,6 +294,7 @@ If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { n->cond = std::move(cond); n->true_branch = std::move(true_branch); n->false_branch = std::move(false_branch); + n->virtual_device_ = VirtualDevice::FullyUnconstrained(); n->span = std::move(span); data_ = std::move(n); } @@ -308,7 +310,7 @@ If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branc bool unchanged = cond.same_as(if_expr->cond) && true_branch.same_as(if_expr->true_branch) && false_branch.same_as(if_expr->false_branch) && - virtual_device.same_as(if_expr->virtual_device_) && span.same_as(if_expr->span); + virtual_device.same_as(if_expr->virtual_device()) && span.same_as(if_expr->span); if (!unchanged) { IfNode* cow_if_node = if_expr.CopyOnWrite(); @@ -339,6 +341,7 @@ TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { ObjectPtr n = make_object(); n->tuple = std::move(tuple); n->index = index; + n->virtual_device_ = VirtualDevice::FullyUnconstrained(); n->span = std::move(span); data_ = std::move(n); } @@ -352,7 +355,7 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, Span span = opt_span.value_or(tuple_get_item->span); bool unchanged = tuple.same_as(tuple_get_item->tuple) && (index == tuple_get_item->index) && - virtual_device.same_as(tuple_get_item->virtual_device_) && + virtual_device.same_as(tuple_get_item->virtual_device()) && span.same_as(tuple_get_item->span); if (!unchanged) { TupleGetItemNode* cow_tuple_get_item_node = tuple_get_item.CopyOnWrite(); @@ -379,6 +382,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) RefCreate::RefCreate(Expr value, Span span) { ObjectPtr n = make_object(); n->value = std::move(value); + n->virtual_device_ = VirtualDevice::FullyUnconstrained(); n->span = std::move(span); data_ = std::move(n); } @@ -390,7 +394,7 @@ RefCreate WithFields(RefCreate ref_create, Optional opt_value, Span span = opt_span.value_or(ref_create->span); bool unchanged = value.same_as(ref_create->value) && - virtual_device.same_as(ref_create->virtual_device_) && + virtual_device.same_as(ref_create->virtual_device()) && span.same_as(ref_create->span); if (!unchanged) { RefCreateNode* cow_ref_create_node = ref_create.CopyOnWrite(); @@ -416,6 +420,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) RefRead::RefRead(Expr ref, Span span) { ObjectPtr n = make_object(); n->ref = std::move(ref); + n->virtual_device_ = VirtualDevice::FullyUnconstrained(); n->span = std::move(span); data_ = std::move(n); } @@ -427,7 +432,7 @@ RefRead WithFields(RefRead ref_read, Optional opt_ref, Span span = opt_span.value_or(ref_read->span); bool unchanged = ref.same_as(ref_read->ref) && - virtual_device.same_as(ref_read->virtual_device_) && + virtual_device.same_as(ref_read->virtual_device()) && span.same_as(ref_read->span); if (!unchanged) { RefReadNode* cow_ref_read_node = ref_read.CopyOnWrite(); @@ -452,6 +457,7 @@ RefWrite::RefWrite(Expr ref, Expr value, Span span) { ObjectPtr n = make_object(); n->ref = std::move(ref); n->value = std::move(value); + n->virtual_device_ = VirtualDevice::FullyUnconstrained(); n->span = std::move(span); data_ = std::move(n); } @@ -464,7 +470,7 @@ RefWrite WithFields(RefWrite ref_write, Optional opt_ref, Optional o Span span = opt_span.value_or(ref_write->span); bool unchanged = ref.same_as(ref_write->ref) && value.same_as(ref_write->value) && - virtual_device.same_as(ref_write->virtual_device_) && + virtual_device.same_as(ref_write->virtual_device()) && span.same_as(ref_write->span); if (!unchanged) { RefWriteNode* cow_ref_write_node = ref_write.CopyOnWrite(); diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 0c54299a2c61..bf0dd577a4d2 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -36,6 +36,7 @@ Function::Function(tvm::Array params, Expr body, Type ret_type, n->ret_type = std::move(ret_type); n->type_params = std::move(type_params); n->attrs = std::move(attrs); + n->virtual_device_ = VirtualDevice::FullyUnconstrained(); n->span = std::move(span); data_ = std::move(n); } @@ -54,7 +55,7 @@ Function WithFields(Function function, Optional> opt_params, Optional bool unchanged = body.same_as(function->body) && ret_type.same_as(function->ret_type) && attrs.same_as(function->attrs) && - virtual_device.same_as(function->virtual_device_) && + virtual_device.same_as(function->virtual_device()) && span.same_as(function->span); // Check that all the type params are unchanged From 6427df257aeb76b02fdca0ea1d746b3288266d00 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Wed, 5 Jan 2022 16:54:02 -0800 Subject: [PATCH 11/16] visit virtual device in the attr visitor Fix serialization tests --- include/tvm/relay/expr.h | 1 + src/relay/ir/expr.cc | 6 +++++- src/relay/transforms/de_duplicate.cc | 1 + tests/python/relay/test_json_compact.py | 15 ++++++++++++--- 4 files changed, 19 insertions(+), 4 deletions(-) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index dcb7838a1b72..fe570806922f 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -72,6 +72,7 @@ class ConstantNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index bab0ba415b05..5fda138f42fb 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -28,6 +28,10 @@ namespace tvm { VirtualDevice RelayExprNode::virtual_device() const { + ICHECK(this->virtual_device_.defined()) + << "Virtual device should always be defined for all RelayNodes except Match and Clause. " + "Found this relay expression without a virtual device: \n" + << PrettyPrint(GetRef(this)); return Downcast(this->virtual_device_); } @@ -145,7 +149,7 @@ Var WithFields(Var var, Optional opt_vid, Optional opt_type_annotation VarNode* cow_var_node = var.CopyOnWrite(); cow_var_node->vid = vid; cow_var_node->type_annotation = type_annotation; - cow_var_node->virtual_device_ = var->virtual_device(); + cow_var_node->virtual_device_ = virtual_device; cow_var_node->span = span; } return var; diff --git a/src/relay/transforms/de_duplicate.cc b/src/relay/transforms/de_duplicate.cc index 2fd88736bf31..b3e88376abcb 100644 --- a/src/relay/transforms/de_duplicate.cc +++ b/src/relay/transforms/de_duplicate.cc @@ -52,6 +52,7 @@ Expr DeDup(const Expr& e) { Expr DispatchVisitExpr(const Expr& e) final { auto ret = ExprMutator::VisitExpr(e); ret->checked_type_ = e->checked_type_; + ret->virtual_device_ = e->virtual_device_; return ret; } diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index 5a7084eb53ab..459ebb6d1da4 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -52,7 +52,13 @@ def test_var(): {"type_key": ""}, { "type_key": "relay.Var", - "attrs": {"_checked_type_": "0", "span": "0", "type_annotation": "0", "vid": "2"}, + "attrs": { + "_checked_type_": "0", + "span": "0", + "type_annotation": "0", + "vid": "2", + "virtual_device_": "0", + }, }, {"type_key": "relay.Id", "attrs": {"name_hint": "a3"}}, {"type_key": "relay.TensorType", "attrs": {"dtype": "float32", "shape": "4", "span": "0"}}, @@ -120,7 +126,7 @@ def test_global_var(): {"type_key": ""}, { "type_key": "relay.GlobalVar", - "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0"}, + "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0", "virtual_device_": "0"}, }, ] data = { @@ -133,7 +139,10 @@ def test_global_var(): assert isinstance(tvar, tvm.ir.GlobalVar) nodes = [ {"type_key": ""}, - {"type_key": "GlobalVar", "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0"}}, + { + "type_key": "GlobalVar", + "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0", "virtual_device_": "0"}, + }, ] data = { "root": 1, From dd196457195776a721b0496da3964dad5340121f Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Fri, 7 Jan 2022 17:47:07 -0800 Subject: [PATCH 12/16] Fix tests after bad merge --- tests/python/relay/test_json_compact.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index 459ebb6d1da4..56e82862bc18 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -57,7 +57,6 @@ def test_var(): "span": "0", "type_annotation": "0", "vid": "2", - "virtual_device_": "0", }, }, {"type_key": "relay.Id", "attrs": {"name_hint": "a3"}}, @@ -126,7 +125,7 @@ def test_global_var(): {"type_key": ""}, { "type_key": "relay.GlobalVar", - "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0", "virtual_device_": "0"}, + "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0"}, }, ] data = { @@ -141,7 +140,7 @@ def test_global_var(): {"type_key": ""}, { "type_key": "GlobalVar", - "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0", "virtual_device_": "0"}, + "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0"}, }, ] data = { From ad9eb3e587aa5a702e0f9ff67f38d26664b5ba19 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Fri, 7 Jan 2022 17:57:48 -0800 Subject: [PATCH 13/16] Change virtual_device() getter method --- src/relay/ir/expr.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 5fda138f42fb..64d921efe6a6 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -28,10 +28,12 @@ namespace tvm { VirtualDevice RelayExprNode::virtual_device() const { - ICHECK(this->virtual_device_.defined()) - << "Virtual device should always be defined for all RelayNodes except Match and Clause. " - "Found this relay expression without a virtual device: \n" - << PrettyPrint(GetRef(this)); + if (!this->virtual_device_.defined()) { + // virtual_device_ should always be defined, unless we imported this node from JSON using an old + // version of TVM, in which case we want to set it to the default, which is + // VirtualDevice::FullyUnconstrained(). + return VirtualDevice::FullyUnconstrained(); + } return Downcast(this->virtual_device_); } From 4769f419e09fd0f047974b9d9eeecb194527a46d Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 11 Jan 2022 11:10:34 -0800 Subject: [PATCH 14/16] lint --- python/tvm/ir/json_compact.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index 55218d71fb5e..9666475b8039 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -69,7 +69,7 @@ def create_updater_08_to_09(): def _initialize_virtual_device(item, _): if "virtual_device_" not in item["attrs"]: - item["attrs"]["virtual_device_"] = "0" + item["attrs"]["virtual_device_"] = "0" return item node_map = { From 9906062bb9e345e2361c92bd5e3375f68715d485 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 11 Jan 2022 11:20:14 -0800 Subject: [PATCH 15/16] ci failed From abe91c44a7ed25682635f37bf93883232854011e Mon Sep 17 00:00:00 2001 From: electriclilies Date: Tue, 11 Jan 2022 12:05:00 -0800 Subject: [PATCH 16/16] ci was broken