diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 2cfc8467a1cb..c7d0f58e3d9f 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -180,6 +180,8 @@ class RelayExprNode : public BaseExprNode { * the call to the function or closure is stored (instead of where the function itself is stored). * The VirtualDevice's Target field describes how the body of the function should be compiled. * + * Set to VirtualDevice::FullyUnconstrained by default. + * * \note Unfortunately, the type of virtual_device_ needs to be ObjectRef to avoid a circular * import. */ diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index dcb7838a1b72..fe570806922f 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -72,6 +72,7 @@ class ConstantNode : public ExprNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("data", &data); + v->Visit("virtual_device_", &virtual_device_); v->Visit("span", &span); v->Visit("_checked_type_", &checked_type_); } diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index f8cb4f0728e0..64d921efe6a6 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -28,10 +28,13 @@ namespace tvm { VirtualDevice RelayExprNode::virtual_device() const { - if (virtual_device_.defined()) { - return Downcast(this->virtual_device_); + if (!this->virtual_device_.defined()) { + // virtual_device_ should always be defined, unless we imported this node from JSON using an old + // version of TVM, in which case we want to set it to the default, which is + // VirtualDevice::FullyUnconstrained(). + return VirtualDevice::FullyUnconstrained(); } - return VirtualDevice::FullyUnconstrained(); + return Downcast(this->virtual_device_); } namespace relay { @@ -76,6 +79,7 @@ TensorType ConstantNode::tensor_type() const { Tuple::Tuple(tvm::Array fields, Span span) { ObjectPtr n = make_object(); n->fields = std::move(fields); + n->virtual_device_ = VirtualDevice::FullyUnconstrained(); n->span = std::move(span); data_ = std::move(n); } @@ -100,7 +104,8 @@ Tuple WithFields(Tuple tuple, Optional> opt_fields, all_fields_unchanged = false; } - all_fields_unchanged = all_fields_unchanged && span.same_as(tuple->span); + all_fields_unchanged = all_fields_unchanged && virtual_device.same_as(tuple->virtual_device()) && + span.same_as(tuple->span); if (!all_fields_unchanged) { TupleNode* cow_tuple_node = tuple.CopyOnWrite(); cow_tuple_node->fields = fields; @@ -120,6 +125,7 @@ Var::Var(Id vid, Type type_annotation, Span span) { ObjectPtr n = make_object(); n->vid = std::move(vid); n->type_annotation = std::move(type_annotation); + n->virtual_device_ = VirtualDevice::FullyUnconstrained(); n->span = std::move(span); data_ = std::move(n); } @@ -139,7 +145,7 @@ Var WithFields(Var var, Optional opt_vid, Optional opt_type_annotation Span span = opt_span.value_or(var->span); bool unchanged = vid.same_as(var->vid) && type_annotation.same_as(var->type_annotation) && - span.same_as(var->span); + virtual_device.same_as(var->virtual_device()) && span.same_as(var->span); if (!unchanged) { VarNode* cow_var_node = var.CopyOnWrite(); @@ -174,6 +180,7 @@ Call::Call(Expr op, Array args, Attrs attrs, Array type_args, Span s n->args = std::move(args); n->attrs = std::move(attrs); n->type_args = std::move(type_args); + n->virtual_device_ = VirtualDevice::FullyUnconstrained(); n->span = std::move(span); data_ = std::move(n); } @@ -188,7 +195,8 @@ Call WithFields(Call call, Optional opt_op, Optional> opt_args VirtualDevice virtual_device = opt_virtual_device.value_or(call->virtual_device()); Span span = opt_span.value_or(call->span); - bool unchanged = op.same_as(call->op) && attrs.same_as(call->attrs) && span.same_as(call->span); + bool unchanged = op.same_as(call->op) && attrs.same_as(call->attrs) && + virtual_device.same_as(call->virtual_device()) && span.same_as(call->span); // Check that the args are unchanged if (unchanged) { @@ -248,6 +256,7 @@ Let::Let(Var var, Expr value, Expr body, Span span) { n->var = std::move(var); n->value = std::move(value); n->body = std::move(body); + n->virtual_device_ = VirtualDevice::FullyUnconstrained(); n->span = std::move(span); data_ = std::move(n); } @@ -261,7 +270,7 @@ Let WithFields(Let let, Optional opt_var, Optional opt_value, Optiona Span span = opt_span.value_or(let->span); bool unchanged = var.same_as(let->var) && value.same_as(let->value) && body.same_as(let->body) && - span.same_as(let->span); + virtual_device.same_as(let->virtual_device()) && span.same_as(let->span); if (!unchanged) { LetNode* cow_let_node = let.CopyOnWrite(); @@ -291,6 +300,7 @@ If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { n->cond = std::move(cond); n->true_branch = std::move(true_branch); n->false_branch = std::move(false_branch); + n->virtual_device_ = VirtualDevice::FullyUnconstrained(); n->span = std::move(span); data_ = std::move(n); } @@ -305,7 +315,8 @@ If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branc Span span = opt_span.value_or(if_expr->span); bool unchanged = cond.same_as(if_expr->cond) && true_branch.same_as(if_expr->true_branch) && - false_branch.same_as(if_expr->false_branch) && span.same_as(if_expr->span); + false_branch.same_as(if_expr->false_branch) && + virtual_device.same_as(if_expr->virtual_device()) && span.same_as(if_expr->span); if (!unchanged) { IfNode* cow_if_node = if_expr.CopyOnWrite(); @@ -336,6 +347,7 @@ TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { ObjectPtr n = make_object(); n->tuple = std::move(tuple); n->index = index; + n->virtual_device_ = VirtualDevice::FullyUnconstrained(); n->span = std::move(span); data_ = std::move(n); } @@ -349,6 +361,7 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, Span span = opt_span.value_or(tuple_get_item->span); bool unchanged = tuple.same_as(tuple_get_item->tuple) && (index == tuple_get_item->index) && + virtual_device.same_as(tuple_get_item->virtual_device()) && span.same_as(tuple_get_item->span); if (!unchanged) { TupleGetItemNode* cow_tuple_get_item_node = tuple_get_item.CopyOnWrite(); @@ -375,6 +388,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) RefCreate::RefCreate(Expr value, Span span) { ObjectPtr n = make_object(); n->value = std::move(value); + n->virtual_device_ = VirtualDevice::FullyUnconstrained(); n->span = std::move(span); data_ = std::move(n); } @@ -385,7 +399,9 @@ RefCreate WithFields(RefCreate ref_create, Optional opt_value, VirtualDevice virtual_device = opt_virtual_device.value_or(ref_create->virtual_device()); Span span = opt_span.value_or(ref_create->span); - bool unchanged = value.same_as(ref_create->value) && span.same_as(ref_create->span); + bool unchanged = value.same_as(ref_create->value) && + virtual_device.same_as(ref_create->virtual_device()) && + span.same_as(ref_create->span); if (!unchanged) { RefCreateNode* cow_ref_create_node = ref_create.CopyOnWrite(); cow_ref_create_node->value = value; @@ -410,6 +426,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) RefRead::RefRead(Expr ref, Span span) { ObjectPtr n = make_object(); n->ref = std::move(ref); + n->virtual_device_ = VirtualDevice::FullyUnconstrained(); n->span = std::move(span); data_ = std::move(n); } @@ -420,7 +437,9 @@ RefRead WithFields(RefRead ref_read, Optional opt_ref, VirtualDevice virtual_device = opt_virtual_device.value_or(ref_read->virtual_device()); Span span = opt_span.value_or(ref_read->span); - bool unchanged = ref.same_as(ref_read->ref) && span.same_as(ref_read->span); + bool unchanged = ref.same_as(ref_read->ref) && + virtual_device.same_as(ref_read->virtual_device()) && + span.same_as(ref_read->span); if (!unchanged) { RefReadNode* cow_ref_read_node = ref_read.CopyOnWrite(); cow_ref_read_node->ref = ref; @@ -444,6 +463,7 @@ RefWrite::RefWrite(Expr ref, Expr value, Span span) { ObjectPtr n = make_object(); n->ref = std::move(ref); n->value = std::move(value); + n->virtual_device_ = VirtualDevice::FullyUnconstrained(); n->span = std::move(span); data_ = std::move(n); } @@ -456,6 +476,7 @@ RefWrite WithFields(RefWrite ref_write, Optional opt_ref, Optional o Span span = opt_span.value_or(ref_write->span); bool unchanged = ref.same_as(ref_write->ref) && value.same_as(ref_write->value) && + virtual_device.same_as(ref_write->virtual_device()) && span.same_as(ref_write->span); if (!unchanged) { RefWriteNode* cow_ref_write_node = ref_write.CopyOnWrite(); diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 43305402557a..bf0dd577a4d2 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -36,6 +36,7 @@ Function::Function(tvm::Array params, Expr body, Type ret_type, n->ret_type = std::move(ret_type); n->type_params = std::move(type_params); n->attrs = std::move(attrs); + n->virtual_device_ = VirtualDevice::FullyUnconstrained(); n->span = std::move(span); data_ = std::move(n); } @@ -53,7 +54,9 @@ Function WithFields(Function function, Optional> opt_params, Optional Span span = opt_span.value_or(function->span); bool unchanged = body.same_as(function->body) && ret_type.same_as(function->ret_type) && - attrs.same_as(function->attrs) && span.same_as(function->span); + attrs.same_as(function->attrs) && + virtual_device.same_as(function->virtual_device()) && + span.same_as(function->span); // Check that all the type params are unchanged if (unchanged) { diff --git a/src/relay/transforms/de_duplicate.cc b/src/relay/transforms/de_duplicate.cc index 2fd88736bf31..b3e88376abcb 100644 --- a/src/relay/transforms/de_duplicate.cc +++ b/src/relay/transforms/de_duplicate.cc @@ -52,6 +52,7 @@ Expr DeDup(const Expr& e) { Expr DispatchVisitExpr(const Expr& e) final { auto ret = ExprMutator::VisitExpr(e); ret->checked_type_ = e->checked_type_; + ret->virtual_device_ = e->virtual_device_; return ret; } diff --git a/tests/python/relay/test_json_compact.py b/tests/python/relay/test_json_compact.py index 5a7084eb53ab..56e82862bc18 100644 --- a/tests/python/relay/test_json_compact.py +++ b/tests/python/relay/test_json_compact.py @@ -52,7 +52,12 @@ def test_var(): {"type_key": ""}, { "type_key": "relay.Var", - "attrs": {"_checked_type_": "0", "span": "0", "type_annotation": "0", "vid": "2"}, + "attrs": { + "_checked_type_": "0", + "span": "0", + "type_annotation": "0", + "vid": "2", + }, }, {"type_key": "relay.Id", "attrs": {"name_hint": "a3"}}, {"type_key": "relay.TensorType", "attrs": {"dtype": "float32", "shape": "4", "span": "0"}}, @@ -133,7 +138,10 @@ def test_global_var(): assert isinstance(tvar, tvm.ir.GlobalVar) nodes = [ {"type_key": ""}, - {"type_key": "GlobalVar", "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0"}}, + { + "type_key": "GlobalVar", + "attrs": {"_checked_type_": "0", "name_hint": "x", "span": "0"}, + }, ] data = { "root": 1,