From 0fe8291907a097996bb041464eecba5870f51ba3 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Thu, 2 Dec 2021 15:49:04 -0800 Subject: [PATCH 1/5] Add virtual device as a first class field to Relay nodes --- include/tvm/ir/expr.h | 27 ++++++++++ include/tvm/relay/expr.h | 102 +++++++++++++++++++++-------------- include/tvm/relay/function.h | 3 ++ src/relay/ir/expr.cc | 51 ++++++++++++++---- src/relay/ir/function.cc | 5 +- 5 files changed, 138 insertions(+), 50 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index b910d32ceca4..310d20812ade 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -39,6 +39,9 @@ namespace tvm { using tvm::runtime::String; +// Forward-declare SEScope to avoid circular imports. +class SEScope; + /*! * \brief Base type of all the expressions. * \sa Expr @@ -165,6 +168,30 @@ class RelayExprNode : public BaseExprNode { template inline const TTypeNode* type_as() const; + /*! + * \brief The virtual device (SEScope) for this node (the result of device planning). + * For first-order expressions (non functions), this describes where the result of evaluating the + * expression should be stored. Note that currently, all composite first-order values (tuples, + * references, ADTs) must be stored on the same virtual device. This means that it is not possible + * to store two tuple fields on different devices, so we only need one virtual device for these + * types. + * + * For expressions that have the function type, the virtual device describes where the result of + * the call to the function or closure is stored (instead of where the function itself is stored). + * The SEScope's Target field describes how the body of the function should be compiled. + * + * \note Unfortunately, the type of virtual_device_ needs to be ObjectRef to avoid a circular + * import. We can forward-declare the SEScope type for the getter function, but not for the field + * itself. + */ + mutable ObjectRef virtual_device_; + + /*! + * \return The virtual device (SEScope). + * If the virtual device is not defined, returns SEScope::FullyUnconstrained(). + */ + SEScope virtual_device() const; + static constexpr const char* _type_key = "RelayExpr"; static constexpr const uint32_t _type_child_slots = 22; TVM_DECLARE_BASE_OBJECT_INFO(RelayExprNode, BaseExprNode); diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index f57b2d1a1952..0f45970c9ea1 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -151,10 +152,14 @@ class Tuple : public Expr { * \param tuple The tuple to copy * \param opt_fields The (optional) fields for the copied tuple. If none, ret_tuple->fields = * tuple->fields. - * \param opt_span The (optional) span for the copied tuple. If none, ret_tuple->span = tuple->span. + * \param opt_virtual_device The (optional) virtual_device for the copied tuple. If none, + * ret_tuple->virtual_device = tuple->virtual_device. + * \param opt_span The (optional) span for the copied tuple. If none, + * ret_tuple->span = tuple->span. */ Tuple WithFields(Tuple tuple, Optional> opt_fields = Optional>(), - Optional opt_span = Optional(nullptr)); + Optional opt_virtual_device = Optional(), + Optional opt_span = Optional()); /*! * \brief Local variables used in the let expression. @@ -240,14 +245,16 @@ class Var : public Expr { * \param opt_vid The (optional) vid for the copied var. If none, ret_var->vid = var->vid. * \param opt_type_annotation The (optional) type_annotation for the copied var. If none, * ret_var->type_annotation = var->type_annotation. - * \param opt_span The (optional) span for the copied var. If none, ret_var->span = var->span. - * \return If all properties are null or the same as the property in the input var - * (i.e., opt_vid is null or opt_vid.value() == var->vid, etc.), then we return var. Otherwise, - * we return a copy of call with the different fields overwritten. (i.e., if - * opt_vid.value() != var->vid, then ret_var->vid = opt_.value()). + * \param opt_virtual_device The (optional) virtual_device for the copied tuple. If none, + * ret_tuple->virtual_device = tuple->virtual_device. \param opt_span The (optional) span for the + * copied var. If none, ret_var->span = var->span. \return If all properties are null or the same as + * the property in the input var (i.e., opt_vid is null or opt_vid.value() == var->vid, etc.), then + * we return var. Otherwise, we return a copy of call with the different fields overwritten. (i.e., + * if opt_vid.value() != var->vid, then ret_var->vid = opt_.value()). */ Var WithFields(Var var, Optional opt_vid = Optional(), Optional opt_type_annotation = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! @@ -362,16 +369,18 @@ class Call : public Expr { * call->attrs. * \param opt_type_args The (optional) type args for the copied call. If none, * ret_call->type_args = call->type_args. - * \param opt_span The (optional) span for the copied call. If none, ret_call->span = call->span. - * \return If all properties are null or the same as the property in the input call - * (i.e., opt_op is null or opt_op.value() == call->op, etc.), then we return call. Otherwise, we - * return a copy of call with the different fields overwritten. (i.e., if opt_op.value() != - * call->op, then ret_call->op = opt_op.value()). + * \param opt_virtual_device The (optional) virtual_device for the copied call. If none, + * ret_call->virtual_device = call->virtual_device. \param opt_span The (optional) span for the + * copied call. If none, ret_call->span = call->span. \return If all properties are null or the same + * as the property in the input call (i.e., opt_op is null or opt_op.value() == call->op, etc.), + * then we return call. Otherwise, we return a copy of call with the different fields overwritten. + * (i.e., if opt_op.value() != call->op, then ret_call->op = opt_op.value()). */ Call WithFields(Call call, Optional opt_op = Optional(), Optional> opt_args = Optional>(), Optional opt_attrs = Optional(), Optional> opt_type_args = Optional>(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! @@ -456,15 +465,17 @@ class Let : public Expr { * \param opt_var The (optional) var for the copied let. If none, ret_let->op = let->op. * \param opt_value The (optional) value for the copied let. If none, ret_let->args = let->args. * \param opt_body The (optional) body for the copied let. If none, ret_let->attrs = let->attrs. - * \param opt_span The (optional) span for the copied let. If none, ret_let->span = let->span. - * \return If all properties are null or the same as the property in the input let (i.e., opt_var is - * null or opt_var.value() == let->var, etc.), then we return let. Otherwise, we return a copy of - * let with the different fields overwritten. (i.e., if opt_var.value() != let->var, then - * ret_let->var = opt_var.value()). + * \param opt_virtual_device The (optional) virtual_device for the copied let. If none, + * ret_let->virtual_device = let->virtual_device. \param opt_span The (optional) span for the copied + * let. If none, ret_let->span = let->span. \return If all properties are null or the same as the + * property in the input let (i.e., opt_var is null or opt_var.value() == let->var, etc.), then we + * return let. Otherwise, we return a copy of let with the different fields overwritten. (i.e., if + * opt_var.value() != let->var, then ret_let->var = opt_var.value()). */ Let WithFields(Let let, Optional opt_var = Optional(), Optional opt_value = Optional(), Optional opt_body = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! @@ -539,17 +550,19 @@ class If : public Expr { * ret_if->true_branch = ret_if->false_branch. * \param opt_false_branch The (optional) false_branch * for the copied if_expr. If none, ret_if->false_branch = if_expr->false_branch. - * \param opt_span - * The (optional) span for the copied if_expr. If none, ret_if->span = if_expr->span. - * \return If all - * properties are null or the same as the property in the input if_expr (i.e., opt_cond is null or - * opt_cond.value() == if_expr->cond, etc.), then we return if_expr. Otherwise, we return a copy of - * if_expr with the different fields overwritten. (i.e., if opt_cond.value() != if_expr->cond, then - * ret_if->cond = opt_cond.value()). + * \param opt_virtual_device The (optional) virtual_device for the copied if_expr. If none, + * ret_if->virtual_device = if_expr->virtual_device. + * \param opt_span The (optional) span for the copied if_expr. If none, + * ret_if->span = if_expr->span. + * \return If all properties are null or the same as the property in + * the input if_expr (i.e., opt_cond is null or opt_cond.value() == if_expr->cond, etc.), then we + * return if_expr. Otherwise, we return a copy of if_expr with the different fields overwritten. + * (i.e., if opt_cond.value() != if_expr->cond, then ret_if->cond = opt_cond.value()). */ If WithFields(If if_expr, Optional opt_cond = Optional(), Optional opt_true_branch = Optional(), Optional opt_false_branch = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! \brief Get index-th field out of a tuple. */ @@ -603,8 +616,9 @@ class TupleGetItem : public Expr { * ret_tuple_get_item->tuple = tuple_get_item->tuple. * \param opt_index The (optional) index for the copied tuple_get_item. If none, * ret_tuple_get_item->index = tuple_get_item->index. - * \param - * opt_span The (optional) span for the copied tuple_get_item. If none, + * \param opt_virtual_device The (optional) virtual_device for the copied tuple_get_item. + * If none, ret_tuple_get_item->virtual_device = tuple_get_item->virtual_device. + * \param opt_span The (optional) span for the copied tuple_get_item. If none, * ret_tuple_get_item->span = tuple_get_item->span. * \return If all properties are null or the same as the property in the input tuple_get_item * (i.e., opt_tuple is null or opt_tuple.value() == tuple_get_item->tuple, etc.), then we return @@ -614,6 +628,7 @@ class TupleGetItem : public Expr { */ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple = Optional(), Optional opt_index = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! \brief Create a new Reference out of initial value. */ @@ -663,6 +678,8 @@ class RefCreate : public Expr { * \param ref_create The ref_create to copy. * \param opt_value The (optional) value for the copied ref_create. If none, * ret_ref_create->value = ref_create->value. + * \param opt_virtual_device The (optional) virtual_device for the copied ref_create. If none, + * ret_ref_create->virtual_device = ref_create->virtual_device. * \param opt_span The (optional) span for the copied ref_create. If none, * ret_ref_create->span = ref_create->span. * \return If all properties are null or the same as the property in the input ref_create @@ -672,6 +689,7 @@ class RefCreate : public Expr { * ret_ref_create->value = opt_value.value()). */ RefCreate WithFields(RefCreate ref_create, Optional opt_value = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! \brief Get value out of Reference. */ @@ -720,15 +738,17 @@ class RefRead : public Expr { * \param ref_read The ref_read to copy. * \param opt_ref The (optional) ref for the copied ref_read. If none, ret_ref_read->ref = * ref_read->ref. - * \param opt_span - * The (optional) span for the copied ref_read. If none, ret_ref_read->span = ref_read->span. - * \return If all properties are null or the same as the property in the input ref_read - * (i.e., opt_ref is null or opt_ref.value() == ref_read->ref, etc.), then we return ref_read. - * Otherwise, we return a copy of ref_read with the different fields overwritten. - * (i.e., if opt_ref.value() != ref_read->ref, then - * ret_ref_read->ref = opt_ref.value()). + * \param opt_virtual_device + * The (optional) virtual_device for the copied ref_read. If none, ret_ref_read->virtual_device = + * ref_read->virtual_device. \param opt_span The (optional) span for the copied ref_read. If none, + * ret_ref_read->span = ref_read->span. \return If all properties are null or the same as the + * property in the input ref_read (i.e., opt_ref is null or opt_ref.value() == ref_read->ref, etc.), + * then we return ref_read. Otherwise, we return a copy of ref_read with the different fields + * overwritten. (i.e., if opt_ref.value() != ref_read->ref, then ret_ref_read->ref = + * opt_ref.value()). */ RefRead WithFields(RefRead ref_read, Optional opt_ref = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */ @@ -784,16 +804,18 @@ class RefWrite : public Expr { * ret_ref_write->ref = ref_write->ref. * \param opt_value The (optional) value for the copied ref_write. If none, * ret_ref_write->value = ref_write->value. - * \param opt_span - * The (optional) span for the copied ref_write. If none, ret_ref_write->span = ref_write->span. - * \return If all properties are null or the same as the property in the input ref_write - * (i.e., opt_ref is null or opt_ref.value() == ref_write->ref, etc.), then we return ref_write. - * Otherwise, we return a copy of ref_write with the different fields overwritten. - * (i.e., if ref_write.value() != ref_write->ref, then - * ret_ref_write->ref = opt_ref.value()). + * \param opt_virtual_device + * The (optional) virtual_device for the copied ref_write. If none, ret_ref_write->virtual_device = + * ref_write->virtual_device. \param opt_span The (optional) span for the copied ref_write. If none, + * ret_ref_write->span = ref_write->span. \return If all properties are null or the same as the + * property in the input ref_write (i.e., opt_ref is null or opt_ref.value() == ref_write->ref, + * etc.), then we return ref_write. Otherwise, we return a copy of ref_write with the different + * fields overwritten. (i.e., if ref_write.value() != ref_write->ref, then ret_ref_write->ref = + * opt_ref.value()). */ RefWrite WithFields(RefWrite ref_write, Optional opt_ref = Optional(), Optional opt_value = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index b0a867082a09..1b8ed4443456 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -134,6 +134,8 @@ class Function : public BaseFunc { * \param opt_attrs * The (optional) attributes for the copied function. If none, * ret_function->attrs = function->attrs. + * \param opt_virtual_device The (optional) virtual_device for the copied function. If none, + * ret_function->virtual_device = function->virtual_device. * \param opt_span The (optional) span for the copied function. If none, * ret_function->span = function->span. * \return If all properties are null or the same as the property in the input function @@ -146,6 +148,7 @@ Function WithFields(Function function, Optional> opt_params = Optiona Optional opt_ret_type = Optional(), Optional> opt_ty_params = Optional>(), Optional opt_attrs = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /* diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 8998f4e1573d..6b4b2f16ce1e 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -23,8 +23,17 @@ */ #include #include +#include namespace tvm { + +SEScope RelayExprNode::virtual_device() const { + if (virtual_device_.defined()) { + return Downcast(this->virtual_device_); + } + return SEScope::FullyUnconstrained(); +} + namespace relay { using tvm::ReprPrinter; @@ -76,8 +85,10 @@ TVM_REGISTER_NODE_TYPE(TupleNode); TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array fields, Span span) { return Tuple(fields, span); }); -Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional opt_span) { +Tuple WithFields(Tuple tuple, Optional> opt_fields, + Optional opt_virtual_device, Optional opt_span) { Array fields = opt_fields.value_or(tuple->fields); + SEScope virtual_device = opt_virtual_device.value_or(tuple->virtual_device()); Span span = opt_span.value_or(tuple->span); bool all_fields_unchanged = true; @@ -93,6 +104,7 @@ Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional o if (!all_fields_unchanged) { TupleNode* cow_tuple_node = tuple.CopyOnWrite(); cow_tuple_node->fields = fields; + cow_tuple_node->virtual_device_ = virtual_device; cow_tuple_node->span = span; } return std::move(tuple); @@ -113,9 +125,10 @@ Var::Var(Id vid, Type type_annotation, Span span) { } Var WithFields(Var var, Optional opt_vid, Optional opt_type_annotation, - Optional opt_span) { + Optional opt_virtual_device, Optional opt_span) { Id vid = opt_vid.value_or(var->vid); Type type_annotation = opt_type_annotation.value_or(var->type_annotation); + SEScope virtual_device = opt_virtual_device.value_or(var->virtual_device()); Span span = opt_span.value_or(var->span); bool unchanged = vid.same_as(var->vid) && type_annotation.same_as(var->type_annotation) && @@ -125,6 +138,7 @@ Var WithFields(Var var, Optional opt_vid, Optional opt_type_annotation VarNode* cow_var_node = var.CopyOnWrite(); cow_var_node->vid = vid; cow_var_node->type_annotation = type_annotation; + cow_var_node->virtual_device_ = virtual_device; cow_var_node->span = span; } return std::move(var); @@ -159,11 +173,12 @@ Call::Call(Expr op, Array args, Attrs attrs, Array type_args, Span s Call WithFields(Call call, Optional opt_op, Optional> opt_args, Optional opt_attrs, Optional> opt_type_args, - Optional opt_span) { + Optional opt_virtual_device, Optional opt_span) { Expr op = opt_op.value_or(call->op); Array args = opt_args.value_or(call->args); Attrs attrs = opt_attrs.value_or(call->attrs); Array type_args = opt_type_args.value_or(call->type_args); + SEScope virtual_device = opt_virtual_device.value_or(call->virtual_device()); Span span = opt_span.value_or(call->span); bool unchanged = op.same_as(call->op) && attrs.same_as(call->attrs) && span.same_as(call->span); @@ -201,6 +216,7 @@ Call WithFields(Call call, Optional opt_op, Optional> opt_args cow_call_node->args = args; cow_call_node->attrs = attrs; cow_call_node->type_args = type_args; + cow_call_node->virtual_device_ = virtual_device; cow_call_node->span = span; } return std::move(call); @@ -230,10 +246,11 @@ Let::Let(Var var, Expr value, Expr body, Span span) { } Let WithFields(Let let, Optional opt_var, Optional opt_value, Optional opt_body, - Optional opt_span) { + Optional opt_virtual_device, Optional opt_span) { Var var = opt_var.value_or(let->var); Expr value = opt_value.value_or(let->value); Expr body = opt_body.value_or(let->body); + SEScope virtual_device = opt_virtual_device.value_or(let->virtual_device()); Span span = opt_span.value_or(let->span); bool unchanged = var.same_as(let->var) && value.same_as(let->value) && body.same_as(let->body) && @@ -244,6 +261,7 @@ Let WithFields(Let let, Optional opt_var, Optional opt_value, Optiona cow_let_node->var = var; cow_let_node->value = value; cow_let_node->body = body; + cow_let_node->virtual_device_ = virtual_device; cow_let_node->span = span; } return std::move(let); @@ -271,10 +289,12 @@ If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { } If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branch, - Optional opt_false_branch, Optional opt_span) { + Optional opt_false_branch, Optional opt_virtual_device, + Optional opt_span) { Expr cond = opt_cond.value_or(if_expr->cond); Expr true_branch = opt_true_branch.value_or(if_expr->true_branch); Expr false_branch = opt_false_branch.value_or(if_expr->false_branch); + SEScope virtual_device = opt_virtual_device.value_or(if_expr->virtual_device()); Span span = opt_span.value_or(if_expr->span); bool unchanged = cond.same_as(if_expr->cond) && true_branch.same_as(if_expr->true_branch) && @@ -285,6 +305,8 @@ If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branc cow_if_node->cond = cond; cow_if_node->true_branch = true_branch; cow_if_node->false_branch = false_branch; + cow_if_node->virtual_device_ = virtual_device; + cow_if_node->span = span; } return std::move(if_expr); } @@ -312,9 +334,11 @@ TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { } TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, - Optional opt_index, Optional opt_span) { + Optional opt_index, Optional opt_virtual_device, + Optional opt_span) { Expr tuple = opt_tuple.value_or(tuple_get_item->tuple); Integer index = opt_index.value_or(tuple_get_item->index); + SEScope virtual_device = opt_virtual_device.value_or(tuple->virtual_device()); Span span = opt_span.value_or(tuple_get_item->span); bool unchanged = tuple.same_as(tuple_get_item->tuple) && (index == tuple_get_item->index) && @@ -324,6 +348,7 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, cow_tuple_get_item_node->tuple = tuple; cow_tuple_get_item_node->index = index; cow_tuple_get_item_node->span = span; + cow_tuple_get_item_node->virtual_device_ = virtual_device; } return std::move(tuple_get_item); } @@ -347,14 +372,17 @@ RefCreate::RefCreate(Expr value, Span span) { data_ = std::move(n); } -RefCreate WithFields(RefCreate ref_create, Optional opt_value, Optional opt_span) { +RefCreate WithFields(RefCreate ref_create, Optional opt_value, + Optional opt_virtual_device, Optional opt_span) { Expr value = opt_value.value_or(ref_create->value); + SEScope virtual_device = opt_virtual_device.value_or(ref_create->virtual_device()); Span span = opt_span.value_or(ref_create->span); bool unchanged = value.same_as(ref_create->value) && span.same_as(ref_create->span); if (!unchanged) { RefCreateNode* cow_ref_create_node = ref_create.CopyOnWrite(); cow_ref_create_node->value = value; + cow_ref_create_node->virtual_device_ = virtual_device; cow_ref_create_node->span = span; } return std::move(ref_create); @@ -379,14 +407,17 @@ RefRead::RefRead(Expr ref, Span span) { data_ = std::move(n); } -RefRead WithFields(RefRead ref_read, Optional opt_ref, Optional opt_span) { +RefRead WithFields(RefRead ref_read, Optional opt_ref, Optional opt_virtual_device, + Optional opt_span) { Expr ref = opt_ref.value_or(ref_read->ref); + SEScope virtual_device = opt_virtual_device.value_or(ref_read->virtual_device()); Span span = opt_span.value_or(ref_read->span); bool unchanged = ref.same_as(ref_read->ref) && span.same_as(ref_read->span); if (!unchanged) { RefReadNode* cow_ref_read_node = ref_read.CopyOnWrite(); cow_ref_read_node->ref = ref; + cow_ref_read_node->virtual_device_ = virtual_device; cow_ref_read_node->span = span; } return std::move(ref_read); @@ -411,9 +442,10 @@ RefWrite::RefWrite(Expr ref, Expr value, Span span) { } RefWrite WithFields(RefWrite ref_write, Optional opt_ref, Optional opt_value, - Optional opt_span) { + Optional opt_virtual_device, Optional opt_span) { Expr ref = opt_ref.value_or(ref_write->ref); Expr value = opt_value.value_or(ref_write->value); + SEScope virtual_device = opt_virtual_device.value_or(ref_write->virtual_device()); Span span = opt_span.value_or(ref_write->span); bool unchanged = ref.same_as(ref_write->ref) && value.same_as(ref_write->value) && @@ -422,6 +454,7 @@ RefWrite WithFields(RefWrite ref_write, Optional opt_ref, Optional o RefWriteNode* cow_ref_write_node = ref_write.CopyOnWrite(); cow_ref_write_node->ref = ref; cow_ref_write_node->value = value; + cow_ref_write_node->virtual_device_ = virtual_device; cow_ref_write_node->span = span; } return std::move(ref_write); diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index f24dd6d1fb4f..f2cb02194009 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -42,12 +42,14 @@ Function::Function(tvm::Array params, Expr body, Type ret_type, Function WithFields(Function function, Optional> opt_params, Optional opt_body, Optional opt_ret_type, Optional> opt_ty_params, - Optional opt_attrs, Optional opt_span) { + Optional opt_attrs, Optional opt_virtual_device, + Optional opt_span) { Array params = opt_params.value_or(function->params); Expr body = opt_body.value_or(function->body); Type ret_type = opt_ret_type.value_or(function->ret_type); Array ty_params = opt_ty_params.value_or(function->type_params); DictAttrs attrs = opt_attrs.value_or(function->attrs); + SEScope virtual_device = opt_virtual_device.value_or(function->virtual_device()); Span span = opt_span.value_or(function->span); bool unchanged = body.same_as(function->body) && ret_type.same_as(function->ret_type) && @@ -86,6 +88,7 @@ Function WithFields(Function function, Optional> opt_params, Optional cow_function_node->ret_type = ret_type; cow_function_node->type_params = ty_params; cow_function_node->attrs = attrs; + cow_function_node->virtual_device_ = virtual_device; cow_function_node->span = span; } return std::move(function); From 1637380ae6700b296a9f9bc68bc6a53d2e1272b6 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Mon, 6 Dec 2021 15:21:59 -0800 Subject: [PATCH 2/5] Flaky test? From 6148fe927b2eb5a4571294aaf77af11e73b54862 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Wed, 8 Dec 2021 13:36:24 -0800 Subject: [PATCH 3/5] Comments --- include/tvm/ir/expr.h | 3 +- include/tvm/relay/expr.h | 59 ++++++++++++++++++++++------------------ 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 310d20812ade..d33606676944 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -181,8 +181,7 @@ class RelayExprNode : public BaseExprNode { * The SEScope's Target field describes how the body of the function should be compiled. * * \note Unfortunately, the type of virtual_device_ needs to be ObjectRef to avoid a circular - * import. We can forward-declare the SEScope type for the getter function, but not for the field - * itself. + * import. */ mutable ObjectRef virtual_device_; diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 0f45970c9ea1..03200d3a3dfb 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -246,11 +246,12 @@ class Var : public Expr { * \param opt_type_annotation The (optional) type_annotation for the copied var. If none, * ret_var->type_annotation = var->type_annotation. * \param opt_virtual_device The (optional) virtual_device for the copied tuple. If none, - * ret_tuple->virtual_device = tuple->virtual_device. \param opt_span The (optional) span for the - * copied var. If none, ret_var->span = var->span. \return If all properties are null or the same as - * the property in the input var (i.e., opt_vid is null or opt_vid.value() == var->vid, etc.), then - * we return var. Otherwise, we return a copy of call with the different fields overwritten. (i.e., - * if opt_vid.value() != var->vid, then ret_var->vid = opt_.value()). + * ret_tuple->virtual_device = tuple->virtual_device. + * \param opt_span The (optional) span for the copied var. If none, ret_var->span = var->span. + * \return If all properties are null or the same as the property in the input var (i.e., opt_vid is + * null or opt_vid.value() == var->vid, etc.), then we return var. Otherwise, we return a copy of + * call with the different fields overwritten. (i.e., if opt_vid.value() != var->vid, then + * ret_var->vid = opt_.value()). */ Var WithFields(Var var, Optional opt_vid = Optional(), Optional opt_type_annotation = Optional(), @@ -370,11 +371,12 @@ class Call : public Expr { * \param opt_type_args The (optional) type args for the copied call. If none, * ret_call->type_args = call->type_args. * \param opt_virtual_device The (optional) virtual_device for the copied call. If none, - * ret_call->virtual_device = call->virtual_device. \param opt_span The (optional) span for the - * copied call. If none, ret_call->span = call->span. \return If all properties are null or the same - * as the property in the input call (i.e., opt_op is null or opt_op.value() == call->op, etc.), - * then we return call. Otherwise, we return a copy of call with the different fields overwritten. - * (i.e., if opt_op.value() != call->op, then ret_call->op = opt_op.value()). + * ret_call->virtual_device = call->virtual_device. + * \param opt_span The (optional) span for the copied call. If none, ret_call->span = call->span. + * \return If all properties are null or the same as the property in the input call (i.e., opt_op is + * null or opt_op.value() == call->op, etc.), then we return call. Otherwise, we return a copy of + * call with the different fields overwritten. (i.e., if opt_op.value() != call->op, then + * ret_call->op = opt_op.value()). */ Call WithFields(Call call, Optional opt_op = Optional(), Optional> opt_args = Optional>(), @@ -466,11 +468,12 @@ class Let : public Expr { * \param opt_value The (optional) value for the copied let. If none, ret_let->args = let->args. * \param opt_body The (optional) body for the copied let. If none, ret_let->attrs = let->attrs. * \param opt_virtual_device The (optional) virtual_device for the copied let. If none, - * ret_let->virtual_device = let->virtual_device. \param opt_span The (optional) span for the copied - * let. If none, ret_let->span = let->span. \return If all properties are null or the same as the - * property in the input let (i.e., opt_var is null or opt_var.value() == let->var, etc.), then we - * return let. Otherwise, we return a copy of let with the different fields overwritten. (i.e., if - * opt_var.value() != let->var, then ret_let->var = opt_var.value()). + * ret_let->virtual_device = let->virtual_device. + * \param opt_span The (optional) span for the copied let. If none, ret_let->span = let->span. + * \return If all properties are null or the same as the property in the input let (i.e., opt_var is + * null or opt_var.value() == let->var, etc.), then we return let. Otherwise, we return a copy of + * let with the different fields overwritten. (i.e., if opt_var.value() != let->var, then + * ret_let->var = opt_var.value()). */ Let WithFields(Let let, Optional opt_var = Optional(), Optional opt_value = Optional(), @@ -740,12 +743,13 @@ class RefRead : public Expr { * ref_read->ref. * \param opt_virtual_device * The (optional) virtual_device for the copied ref_read. If none, ret_ref_read->virtual_device = - * ref_read->virtual_device. \param opt_span The (optional) span for the copied ref_read. If none, - * ret_ref_read->span = ref_read->span. \return If all properties are null or the same as the - * property in the input ref_read (i.e., opt_ref is null or opt_ref.value() == ref_read->ref, etc.), - * then we return ref_read. Otherwise, we return a copy of ref_read with the different fields - * overwritten. (i.e., if opt_ref.value() != ref_read->ref, then ret_ref_read->ref = - * opt_ref.value()). + * ref_read->virtual_device. + * \param opt_span The (optional) span for the copied ref_read. If none, ret_ref_read->span = + * ref_read->span. + * \return If all properties are null or the same as the property in the input + * ref_read (i.e., opt_ref is null or opt_ref.value() == ref_read->ref, etc.), then we return + * ref_read. Otherwise, we return a copy of ref_read with the different fields overwritten. (i.e., + * if opt_ref.value() != ref_read->ref, then ret_ref_read->ref = opt_ref.value()). */ RefRead WithFields(RefRead ref_read, Optional opt_ref = Optional(), Optional opt_virtual_device = Optional(), @@ -806,12 +810,13 @@ class RefWrite : public Expr { * ret_ref_write->value = ref_write->value. * \param opt_virtual_device * The (optional) virtual_device for the copied ref_write. If none, ret_ref_write->virtual_device = - * ref_write->virtual_device. \param opt_span The (optional) span for the copied ref_write. If none, - * ret_ref_write->span = ref_write->span. \return If all properties are null or the same as the - * property in the input ref_write (i.e., opt_ref is null or opt_ref.value() == ref_write->ref, - * etc.), then we return ref_write. Otherwise, we return a copy of ref_write with the different - * fields overwritten. (i.e., if ref_write.value() != ref_write->ref, then ret_ref_write->ref = - * opt_ref.value()). + * ref_write->virtual_device. + * \param opt_span The (optional) span for the copied ref_write. If none, ret_ref_write->span = + * ref_write->span. + * \return If all properties are null or the same as the property in the input ref_write (i.e., + * opt_ref is null or opt_ref.value() == ref_write->ref, etc.), then we return ref_write. Otherwise, + * we return a copy of ref_write with the different fields overwritten. (i.e., if ref_write.value() + * != ref_write->ref, then ret_ref_write->ref = opt_ref.value()). */ RefWrite WithFields(RefWrite ref_write, Optional opt_ref = Optional(), Optional opt_value = Optional(), From f2dbc91c366724f85b0bec8fdef686381dafa48e Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Wed, 8 Dec 2021 16:44:08 -0800 Subject: [PATCH 4/5] Oops, lost these in force push --- rust/tvm-sys/Cargo.toml | 2 +- rust/tvm/src/ir/relay/mod.rs | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/rust/tvm-sys/Cargo.toml b/rust/tvm-sys/Cargo.toml index ccc104ba9223..4494e20afa31 100644 --- a/rust/tvm-sys/Cargo.toml +++ b/rust/tvm-sys/Cargo.toml @@ -85,4 +85,4 @@ enumn = "^0.1" [build-dependencies] bindgen = { version="0.57", default-features = false, features = ["runtime"] } anyhow = "^1.0" -tvm-build = "0.2.1" +tvm-build = "0.2.4" diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index b65b784bf400..404cca4946fb 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -40,6 +40,7 @@ pub mod attrs; pub struct ExprNode { pub base: BaseExprNode, pub checked_type: Type, + pub virtual_device: ObjectRef, } impl ExprNode { @@ -47,6 +48,7 @@ impl ExprNode { ExprNode { base: BaseExprNode::base::(span.clone()), checked_type: Type::null(), + virtual_device: ObjectRef::null(), } } } From dd96a0f2e40c78eb4b6ba9ee61ebe50ba09a2f1e Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Thu, 9 Dec 2021 12:18:16 -0800 Subject: [PATCH 5/5] flaky?