From 0fe8291907a097996bb041464eecba5870f51ba3 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Thu, 2 Dec 2021 15:49:04 -0800 Subject: [PATCH 01/13] Add virtual device as a first class field to Relay nodes --- include/tvm/ir/expr.h | 27 ++++++++++ include/tvm/relay/expr.h | 102 +++++++++++++++++++++-------------- include/tvm/relay/function.h | 3 ++ src/relay/ir/expr.cc | 51 ++++++++++++++---- src/relay/ir/function.cc | 5 +- 5 files changed, 138 insertions(+), 50 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index b910d32ceca4..310d20812ade 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -39,6 +39,9 @@ namespace tvm { using tvm::runtime::String; +// Forward-declare SEScope to avoid circular imports. +class SEScope; + /*! * \brief Base type of all the expressions. * \sa Expr @@ -165,6 +168,30 @@ class RelayExprNode : public BaseExprNode { template inline const TTypeNode* type_as() const; + /*! + * \brief The virtual device (SEScope) for this node (the result of device planning). + * For first-order expressions (non functions), this describes where the result of evaluating the + * expression should be stored. Note that currently, all composite first-order values (tuples, + * references, ADTs) must be stored on the same virtual device. This means that it is not possible + * to store two tuple fields on different devices, so we only need one virtual device for these + * types. + * + * For expressions that have the function type, the virtual device describes where the result of + * the call to the function or closure is stored (instead of where the function itself is stored). + * The SEScope's Target field describes how the body of the function should be compiled. + * + * \note Unfortunately, the type of virtual_device_ needs to be ObjectRef to avoid a circular + * import. We can forward-declare the SEScope type for the getter function, but not for the field + * itself. + */ + mutable ObjectRef virtual_device_; + + /*! + * \return The virtual device (SEScope). + * If the virtual device is not defined, returns SEScope::FullyUnconstrained(). + */ + SEScope virtual_device() const; + static constexpr const char* _type_key = "RelayExpr"; static constexpr const uint32_t _type_child_slots = 22; TVM_DECLARE_BASE_OBJECT_INFO(RelayExprNode, BaseExprNode); diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index f57b2d1a1952..0f45970c9ea1 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include @@ -151,10 +152,14 @@ class Tuple : public Expr { * \param tuple The tuple to copy * \param opt_fields The (optional) fields for the copied tuple. If none, ret_tuple->fields = * tuple->fields. - * \param opt_span The (optional) span for the copied tuple. If none, ret_tuple->span = tuple->span. + * \param opt_virtual_device The (optional) virtual_device for the copied tuple. If none, + * ret_tuple->virtual_device = tuple->virtual_device. + * \param opt_span The (optional) span for the copied tuple. If none, + * ret_tuple->span = tuple->span. */ Tuple WithFields(Tuple tuple, Optional> opt_fields = Optional>(), - Optional opt_span = Optional(nullptr)); + Optional opt_virtual_device = Optional(), + Optional opt_span = Optional()); /*! * \brief Local variables used in the let expression. @@ -240,14 +245,16 @@ class Var : public Expr { * \param opt_vid The (optional) vid for the copied var. If none, ret_var->vid = var->vid. * \param opt_type_annotation The (optional) type_annotation for the copied var. If none, * ret_var->type_annotation = var->type_annotation. - * \param opt_span The (optional) span for the copied var. If none, ret_var->span = var->span. - * \return If all properties are null or the same as the property in the input var - * (i.e., opt_vid is null or opt_vid.value() == var->vid, etc.), then we return var. Otherwise, - * we return a copy of call with the different fields overwritten. (i.e., if - * opt_vid.value() != var->vid, then ret_var->vid = opt_.value()). + * \param opt_virtual_device The (optional) virtual_device for the copied tuple. If none, + * ret_tuple->virtual_device = tuple->virtual_device. \param opt_span The (optional) span for the + * copied var. If none, ret_var->span = var->span. \return If all properties are null or the same as + * the property in the input var (i.e., opt_vid is null or opt_vid.value() == var->vid, etc.), then + * we return var. Otherwise, we return a copy of call with the different fields overwritten. (i.e., + * if opt_vid.value() != var->vid, then ret_var->vid = opt_.value()). */ Var WithFields(Var var, Optional opt_vid = Optional(), Optional opt_type_annotation = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! @@ -362,16 +369,18 @@ class Call : public Expr { * call->attrs. * \param opt_type_args The (optional) type args for the copied call. If none, * ret_call->type_args = call->type_args. - * \param opt_span The (optional) span for the copied call. If none, ret_call->span = call->span. - * \return If all properties are null or the same as the property in the input call - * (i.e., opt_op is null or opt_op.value() == call->op, etc.), then we return call. Otherwise, we - * return a copy of call with the different fields overwritten. (i.e., if opt_op.value() != - * call->op, then ret_call->op = opt_op.value()). + * \param opt_virtual_device The (optional) virtual_device for the copied call. If none, + * ret_call->virtual_device = call->virtual_device. \param opt_span The (optional) span for the + * copied call. If none, ret_call->span = call->span. \return If all properties are null or the same + * as the property in the input call (i.e., opt_op is null or opt_op.value() == call->op, etc.), + * then we return call. Otherwise, we return a copy of call with the different fields overwritten. + * (i.e., if opt_op.value() != call->op, then ret_call->op = opt_op.value()). */ Call WithFields(Call call, Optional opt_op = Optional(), Optional> opt_args = Optional>(), Optional opt_attrs = Optional(), Optional> opt_type_args = Optional>(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! @@ -456,15 +465,17 @@ class Let : public Expr { * \param opt_var The (optional) var for the copied let. If none, ret_let->op = let->op. * \param opt_value The (optional) value for the copied let. If none, ret_let->args = let->args. * \param opt_body The (optional) body for the copied let. If none, ret_let->attrs = let->attrs. - * \param opt_span The (optional) span for the copied let. If none, ret_let->span = let->span. - * \return If all properties are null or the same as the property in the input let (i.e., opt_var is - * null or opt_var.value() == let->var, etc.), then we return let. Otherwise, we return a copy of - * let with the different fields overwritten. (i.e., if opt_var.value() != let->var, then - * ret_let->var = opt_var.value()). + * \param opt_virtual_device The (optional) virtual_device for the copied let. If none, + * ret_let->virtual_device = let->virtual_device. \param opt_span The (optional) span for the copied + * let. If none, ret_let->span = let->span. \return If all properties are null or the same as the + * property in the input let (i.e., opt_var is null or opt_var.value() == let->var, etc.), then we + * return let. Otherwise, we return a copy of let with the different fields overwritten. (i.e., if + * opt_var.value() != let->var, then ret_let->var = opt_var.value()). */ Let WithFields(Let let, Optional opt_var = Optional(), Optional opt_value = Optional(), Optional opt_body = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! @@ -539,17 +550,19 @@ class If : public Expr { * ret_if->true_branch = ret_if->false_branch. * \param opt_false_branch The (optional) false_branch * for the copied if_expr. If none, ret_if->false_branch = if_expr->false_branch. - * \param opt_span - * The (optional) span for the copied if_expr. If none, ret_if->span = if_expr->span. - * \return If all - * properties are null or the same as the property in the input if_expr (i.e., opt_cond is null or - * opt_cond.value() == if_expr->cond, etc.), then we return if_expr. Otherwise, we return a copy of - * if_expr with the different fields overwritten. (i.e., if opt_cond.value() != if_expr->cond, then - * ret_if->cond = opt_cond.value()). + * \param opt_virtual_device The (optional) virtual_device for the copied if_expr. If none, + * ret_if->virtual_device = if_expr->virtual_device. + * \param opt_span The (optional) span for the copied if_expr. If none, + * ret_if->span = if_expr->span. + * \return If all properties are null or the same as the property in + * the input if_expr (i.e., opt_cond is null or opt_cond.value() == if_expr->cond, etc.), then we + * return if_expr. Otherwise, we return a copy of if_expr with the different fields overwritten. + * (i.e., if opt_cond.value() != if_expr->cond, then ret_if->cond = opt_cond.value()). */ If WithFields(If if_expr, Optional opt_cond = Optional(), Optional opt_true_branch = Optional(), Optional opt_false_branch = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! \brief Get index-th field out of a tuple. */ @@ -603,8 +616,9 @@ class TupleGetItem : public Expr { * ret_tuple_get_item->tuple = tuple_get_item->tuple. * \param opt_index The (optional) index for the copied tuple_get_item. If none, * ret_tuple_get_item->index = tuple_get_item->index. - * \param - * opt_span The (optional) span for the copied tuple_get_item. If none, + * \param opt_virtual_device The (optional) virtual_device for the copied tuple_get_item. + * If none, ret_tuple_get_item->virtual_device = tuple_get_item->virtual_device. + * \param opt_span The (optional) span for the copied tuple_get_item. If none, * ret_tuple_get_item->span = tuple_get_item->span. * \return If all properties are null or the same as the property in the input tuple_get_item * (i.e., opt_tuple is null or opt_tuple.value() == tuple_get_item->tuple, etc.), then we return @@ -614,6 +628,7 @@ class TupleGetItem : public Expr { */ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple = Optional(), Optional opt_index = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! \brief Create a new Reference out of initial value. */ @@ -663,6 +678,8 @@ class RefCreate : public Expr { * \param ref_create The ref_create to copy. * \param opt_value The (optional) value for the copied ref_create. If none, * ret_ref_create->value = ref_create->value. + * \param opt_virtual_device The (optional) virtual_device for the copied ref_create. If none, + * ret_ref_create->virtual_device = ref_create->virtual_device. * \param opt_span The (optional) span for the copied ref_create. If none, * ret_ref_create->span = ref_create->span. * \return If all properties are null or the same as the property in the input ref_create @@ -672,6 +689,7 @@ class RefCreate : public Expr { * ret_ref_create->value = opt_value.value()). */ RefCreate WithFields(RefCreate ref_create, Optional opt_value = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! \brief Get value out of Reference. */ @@ -720,15 +738,17 @@ class RefRead : public Expr { * \param ref_read The ref_read to copy. * \param opt_ref The (optional) ref for the copied ref_read. If none, ret_ref_read->ref = * ref_read->ref. - * \param opt_span - * The (optional) span for the copied ref_read. If none, ret_ref_read->span = ref_read->span. - * \return If all properties are null or the same as the property in the input ref_read - * (i.e., opt_ref is null or opt_ref.value() == ref_read->ref, etc.), then we return ref_read. - * Otherwise, we return a copy of ref_read with the different fields overwritten. - * (i.e., if opt_ref.value() != ref_read->ref, then - * ret_ref_read->ref = opt_ref.value()). + * \param opt_virtual_device + * The (optional) virtual_device for the copied ref_read. If none, ret_ref_read->virtual_device = + * ref_read->virtual_device. \param opt_span The (optional) span for the copied ref_read. If none, + * ret_ref_read->span = ref_read->span. \return If all properties are null or the same as the + * property in the input ref_read (i.e., opt_ref is null or opt_ref.value() == ref_read->ref, etc.), + * then we return ref_read. Otherwise, we return a copy of ref_read with the different fields + * overwritten. (i.e., if opt_ref.value() != ref_read->ref, then ret_ref_read->ref = + * opt_ref.value()). */ RefRead WithFields(RefRead ref_read, Optional opt_ref = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */ @@ -784,16 +804,18 @@ class RefWrite : public Expr { * ret_ref_write->ref = ref_write->ref. * \param opt_value The (optional) value for the copied ref_write. If none, * ret_ref_write->value = ref_write->value. - * \param opt_span - * The (optional) span for the copied ref_write. If none, ret_ref_write->span = ref_write->span. - * \return If all properties are null or the same as the property in the input ref_write - * (i.e., opt_ref is null or opt_ref.value() == ref_write->ref, etc.), then we return ref_write. - * Otherwise, we return a copy of ref_write with the different fields overwritten. - * (i.e., if ref_write.value() != ref_write->ref, then - * ret_ref_write->ref = opt_ref.value()). + * \param opt_virtual_device + * The (optional) virtual_device for the copied ref_write. If none, ret_ref_write->virtual_device = + * ref_write->virtual_device. \param opt_span The (optional) span for the copied ref_write. If none, + * ret_ref_write->span = ref_write->span. \return If all properties are null or the same as the + * property in the input ref_write (i.e., opt_ref is null or opt_ref.value() == ref_write->ref, + * etc.), then we return ref_write. Otherwise, we return a copy of ref_write with the different + * fields overwritten. (i.e., if ref_write.value() != ref_write->ref, then ret_ref_write->ref = + * opt_ref.value()). */ RefWrite WithFields(RefWrite ref_write, Optional opt_ref = Optional(), Optional opt_value = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index b0a867082a09..1b8ed4443456 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -134,6 +134,8 @@ class Function : public BaseFunc { * \param opt_attrs * The (optional) attributes for the copied function. If none, * ret_function->attrs = function->attrs. + * \param opt_virtual_device The (optional) virtual_device for the copied function. If none, + * ret_function->virtual_device = function->virtual_device. * \param opt_span The (optional) span for the copied function. If none, * ret_function->span = function->span. * \return If all properties are null or the same as the property in the input function @@ -146,6 +148,7 @@ Function WithFields(Function function, Optional> opt_params = Optiona Optional opt_ret_type = Optional(), Optional> opt_ty_params = Optional>(), Optional opt_attrs = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /* diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 8998f4e1573d..6b4b2f16ce1e 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -23,8 +23,17 @@ */ #include #include +#include namespace tvm { + +SEScope RelayExprNode::virtual_device() const { + if (virtual_device_.defined()) { + return Downcast(this->virtual_device_); + } + return SEScope::FullyUnconstrained(); +} + namespace relay { using tvm::ReprPrinter; @@ -76,8 +85,10 @@ TVM_REGISTER_NODE_TYPE(TupleNode); TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array fields, Span span) { return Tuple(fields, span); }); -Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional opt_span) { +Tuple WithFields(Tuple tuple, Optional> opt_fields, + Optional opt_virtual_device, Optional opt_span) { Array fields = opt_fields.value_or(tuple->fields); + SEScope virtual_device = opt_virtual_device.value_or(tuple->virtual_device()); Span span = opt_span.value_or(tuple->span); bool all_fields_unchanged = true; @@ -93,6 +104,7 @@ Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional o if (!all_fields_unchanged) { TupleNode* cow_tuple_node = tuple.CopyOnWrite(); cow_tuple_node->fields = fields; + cow_tuple_node->virtual_device_ = virtual_device; cow_tuple_node->span = span; } return std::move(tuple); @@ -113,9 +125,10 @@ Var::Var(Id vid, Type type_annotation, Span span) { } Var WithFields(Var var, Optional opt_vid, Optional opt_type_annotation, - Optional opt_span) { + Optional opt_virtual_device, Optional opt_span) { Id vid = opt_vid.value_or(var->vid); Type type_annotation = opt_type_annotation.value_or(var->type_annotation); + SEScope virtual_device = opt_virtual_device.value_or(var->virtual_device()); Span span = opt_span.value_or(var->span); bool unchanged = vid.same_as(var->vid) && type_annotation.same_as(var->type_annotation) && @@ -125,6 +138,7 @@ Var WithFields(Var var, Optional opt_vid, Optional opt_type_annotation VarNode* cow_var_node = var.CopyOnWrite(); cow_var_node->vid = vid; cow_var_node->type_annotation = type_annotation; + cow_var_node->virtual_device_ = virtual_device; cow_var_node->span = span; } return std::move(var); @@ -159,11 +173,12 @@ Call::Call(Expr op, Array args, Attrs attrs, Array type_args, Span s Call WithFields(Call call, Optional opt_op, Optional> opt_args, Optional opt_attrs, Optional> opt_type_args, - Optional opt_span) { + Optional opt_virtual_device, Optional opt_span) { Expr op = opt_op.value_or(call->op); Array args = opt_args.value_or(call->args); Attrs attrs = opt_attrs.value_or(call->attrs); Array type_args = opt_type_args.value_or(call->type_args); + SEScope virtual_device = opt_virtual_device.value_or(call->virtual_device()); Span span = opt_span.value_or(call->span); bool unchanged = op.same_as(call->op) && attrs.same_as(call->attrs) && span.same_as(call->span); @@ -201,6 +216,7 @@ Call WithFields(Call call, Optional opt_op, Optional> opt_args cow_call_node->args = args; cow_call_node->attrs = attrs; cow_call_node->type_args = type_args; + cow_call_node->virtual_device_ = virtual_device; cow_call_node->span = span; } return std::move(call); @@ -230,10 +246,11 @@ Let::Let(Var var, Expr value, Expr body, Span span) { } Let WithFields(Let let, Optional opt_var, Optional opt_value, Optional opt_body, - Optional opt_span) { + Optional opt_virtual_device, Optional opt_span) { Var var = opt_var.value_or(let->var); Expr value = opt_value.value_or(let->value); Expr body = opt_body.value_or(let->body); + SEScope virtual_device = opt_virtual_device.value_or(let->virtual_device()); Span span = opt_span.value_or(let->span); bool unchanged = var.same_as(let->var) && value.same_as(let->value) && body.same_as(let->body) && @@ -244,6 +261,7 @@ Let WithFields(Let let, Optional opt_var, Optional opt_value, Optiona cow_let_node->var = var; cow_let_node->value = value; cow_let_node->body = body; + cow_let_node->virtual_device_ = virtual_device; cow_let_node->span = span; } return std::move(let); @@ -271,10 +289,12 @@ If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { } If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branch, - Optional opt_false_branch, Optional opt_span) { + Optional opt_false_branch, Optional opt_virtual_device, + Optional opt_span) { Expr cond = opt_cond.value_or(if_expr->cond); Expr true_branch = opt_true_branch.value_or(if_expr->true_branch); Expr false_branch = opt_false_branch.value_or(if_expr->false_branch); + SEScope virtual_device = opt_virtual_device.value_or(if_expr->virtual_device()); Span span = opt_span.value_or(if_expr->span); bool unchanged = cond.same_as(if_expr->cond) && true_branch.same_as(if_expr->true_branch) && @@ -285,6 +305,8 @@ If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branc cow_if_node->cond = cond; cow_if_node->true_branch = true_branch; cow_if_node->false_branch = false_branch; + cow_if_node->virtual_device_ = virtual_device; + cow_if_node->span = span; } return std::move(if_expr); } @@ -312,9 +334,11 @@ TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { } TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, - Optional opt_index, Optional opt_span) { + Optional opt_index, Optional opt_virtual_device, + Optional opt_span) { Expr tuple = opt_tuple.value_or(tuple_get_item->tuple); Integer index = opt_index.value_or(tuple_get_item->index); + SEScope virtual_device = opt_virtual_device.value_or(tuple->virtual_device()); Span span = opt_span.value_or(tuple_get_item->span); bool unchanged = tuple.same_as(tuple_get_item->tuple) && (index == tuple_get_item->index) && @@ -324,6 +348,7 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, cow_tuple_get_item_node->tuple = tuple; cow_tuple_get_item_node->index = index; cow_tuple_get_item_node->span = span; + cow_tuple_get_item_node->virtual_device_ = virtual_device; } return std::move(tuple_get_item); } @@ -347,14 +372,17 @@ RefCreate::RefCreate(Expr value, Span span) { data_ = std::move(n); } -RefCreate WithFields(RefCreate ref_create, Optional opt_value, Optional opt_span) { +RefCreate WithFields(RefCreate ref_create, Optional opt_value, + Optional opt_virtual_device, Optional opt_span) { Expr value = opt_value.value_or(ref_create->value); + SEScope virtual_device = opt_virtual_device.value_or(ref_create->virtual_device()); Span span = opt_span.value_or(ref_create->span); bool unchanged = value.same_as(ref_create->value) && span.same_as(ref_create->span); if (!unchanged) { RefCreateNode* cow_ref_create_node = ref_create.CopyOnWrite(); cow_ref_create_node->value = value; + cow_ref_create_node->virtual_device_ = virtual_device; cow_ref_create_node->span = span; } return std::move(ref_create); @@ -379,14 +407,17 @@ RefRead::RefRead(Expr ref, Span span) { data_ = std::move(n); } -RefRead WithFields(RefRead ref_read, Optional opt_ref, Optional opt_span) { +RefRead WithFields(RefRead ref_read, Optional opt_ref, Optional opt_virtual_device, + Optional opt_span) { Expr ref = opt_ref.value_or(ref_read->ref); + SEScope virtual_device = opt_virtual_device.value_or(ref_read->virtual_device()); Span span = opt_span.value_or(ref_read->span); bool unchanged = ref.same_as(ref_read->ref) && span.same_as(ref_read->span); if (!unchanged) { RefReadNode* cow_ref_read_node = ref_read.CopyOnWrite(); cow_ref_read_node->ref = ref; + cow_ref_read_node->virtual_device_ = virtual_device; cow_ref_read_node->span = span; } return std::move(ref_read); @@ -411,9 +442,10 @@ RefWrite::RefWrite(Expr ref, Expr value, Span span) { } RefWrite WithFields(RefWrite ref_write, Optional opt_ref, Optional opt_value, - Optional opt_span) { + Optional opt_virtual_device, Optional opt_span) { Expr ref = opt_ref.value_or(ref_write->ref); Expr value = opt_value.value_or(ref_write->value); + SEScope virtual_device = opt_virtual_device.value_or(ref_write->virtual_device()); Span span = opt_span.value_or(ref_write->span); bool unchanged = ref.same_as(ref_write->ref) && value.same_as(ref_write->value) && @@ -422,6 +454,7 @@ RefWrite WithFields(RefWrite ref_write, Optional opt_ref, Optional o RefWriteNode* cow_ref_write_node = ref_write.CopyOnWrite(); cow_ref_write_node->ref = ref; cow_ref_write_node->value = value; + cow_ref_write_node->virtual_device_ = virtual_device; cow_ref_write_node->span = span; } return std::move(ref_write); diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index f24dd6d1fb4f..f2cb02194009 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -42,12 +42,14 @@ Function::Function(tvm::Array params, Expr body, Type ret_type, Function WithFields(Function function, Optional> opt_params, Optional opt_body, Optional opt_ret_type, Optional> opt_ty_params, - Optional opt_attrs, Optional opt_span) { + Optional opt_attrs, Optional opt_virtual_device, + Optional opt_span) { Array params = opt_params.value_or(function->params); Expr body = opt_body.value_or(function->body); Type ret_type = opt_ret_type.value_or(function->ret_type); Array ty_params = opt_ty_params.value_or(function->type_params); DictAttrs attrs = opt_attrs.value_or(function->attrs); + SEScope virtual_device = opt_virtual_device.value_or(function->virtual_device()); Span span = opt_span.value_or(function->span); bool unchanged = body.same_as(function->body) && ret_type.same_as(function->ret_type) && @@ -86,6 +88,7 @@ Function WithFields(Function function, Optional> opt_params, Optional cow_function_node->ret_type = ret_type; cow_function_node->type_params = ty_params; cow_function_node->attrs = attrs; + cow_function_node->virtual_device_ = virtual_device; cow_function_node->span = span; } return std::move(function); From 49c3374dcb66b809e0880e3857f57735a31f2167 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Thu, 2 Dec 2021 15:11:51 -0800 Subject: [PATCH 02/13] Change representation of virtual devices --- include/tvm/ir/function.h | 10 ---------- src/relay/op/memory/on_device.cc | 22 ++++++++++++++-------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index e466cde097ac..64b9ce964ce1 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -190,16 +190,6 @@ constexpr const char* kTarget = "target"; */ constexpr const char* kGlobalSymbol = "global_symbol"; -/*! - * \brief The SEScope which will hold each of the functions parameters. - * - * Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but - * may be included as an annotation on user programs. - * - * Type: Array - */ -constexpr const char* kParamSEScopes = "param_se_scopes"; - /*! * \brief The SEScope which will hold the function result. * diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc index 9541d4122a2f..cdd82d164b4f 100644 --- a/src/relay/op/memory/on_device.cc +++ b/src/relay/op/memory/on_device.cc @@ -127,8 +127,14 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) { Function FunctionOnDevice(Function function, Array param_se_scopes, SEScope result_se_scope) { - return WithAttrs(std::move(function), {{tvm::attr::kParamSEScopes, std::move(param_se_scopes)}, - {tvm::attr::kResultSEScope, std::move(result_se_scope)}}); + ICHECK(function->params.size() == param_se_scopes.size()) << "ParamSEScopes must be the same size as the function parameters."; + Array new_params; + for (size_t i = 0; i < function->params.size(); i++) { + Var param = function->params[i]; + new_params.push_back(WithFields(std::move(param), {}, {}, std::move(param_se_scopes[i]))); + } + + return WithAttrs(WithFields(std::move(function), std::move(new_params)), {{tvm::attr::kResultSEScope, std::move(result_se_scope)}}); } TVM_REGISTER_GLOBAL("relay.op.annotation._make.FunctionOnDevice").set_body_typed(FunctionOnDevice); @@ -153,14 +159,14 @@ SEScope GetFunctionParamSEScope(const FunctionNode* function_node, size_t i) { ICHECK_LT(i, function_node->params.size()) << "param index " << i << " out of range for function of arity " << function_node->params.size(); - auto opt_array = function_node->GetAttr>(tvm::attr::kParamSEScopes); - if (!opt_array) { - // No annotation. + + // TODO(@electriclilies): Should we still check that all param sescopes are defined here? + SEScope se_scope = function_node->params[i]->virtual_device(); + if (se_scope.defined()) { return SEScope::FullyUnconstrained(); } - ICHECK_EQ(opt_array.value().size(), function_node->params.size()) - << "annotation parameters do not match function arity"; - return opt_array.value()[i]; + + return se_scope; } } // namespace relay From eea3d19408a6d2246cfa406329ab6578721e1c94 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Mon, 6 Dec 2021 12:07:56 -0800 Subject: [PATCH 03/13] Function constructor -> WithFields --- src/relay/analysis/extract_fused_functions.cc | 11 +++---- .../contrib/cmsisnn/extract_constants.cc | 10 +++--- .../backend/contrib/cmsisnn/relay_to_tir.cc | 9 ++---- src/relay/backend/contrib/ethosu/codegen.cc | 8 ++--- .../backend/contrib/ethosu/preprocess.cc | 5 ++- .../example_target_hooks/relay_to_tir.cc | 9 ++---- src/relay/backend/te_compiler.cc | 3 +- src/relay/backend/vm/compiler.cc | 12 +++---- src/relay/backend/vm/lambda_lift.cc | 13 +++++--- src/relay/ir/expr_functor.cc | 32 ++++++------------- src/relay/op/memory/on_device.cc | 17 +++------- src/relay/quantize/annotate.cc | 2 +- src/relay/quantize/calibrate.cc | 3 +- src/relay/transforms/annotate_target.cc | 2 +- src/relay/transforms/convert_sparse_conv2d.cc | 4 +-- src/relay/transforms/convert_sparse_dense.cc | 4 +-- src/relay/transforms/de_duplicate.cc | 8 ++--- src/relay/transforms/defunctionalization.cc | 6 ++-- src/relay/transforms/eta_expand.cc | 3 +- src/relay/transforms/first_order_gradient.cc | 4 +-- src/relay/transforms/fuse_ops.cc | 2 ++ src/relay/transforms/higher_order_gradient.cc | 17 +++++----- src/relay/transforms/inline.cc | 7 ++-- src/relay/transforms/partial_eval.cc | 5 ++- src/relay/transforms/partition_graph.cc | 18 ++++------- src/relay/transforms/pass_utils.h | 2 +- src/relay/transforms/simplify_fc_transpose.cc | 4 +-- src/relay/transforms/to_a_normal_form.cc | 3 +- src/relay/transforms/to_cps.cc | 19 ++++++----- 29 files changed, 102 insertions(+), 140 deletions(-) diff --git a/src/relay/analysis/extract_fused_functions.cc b/src/relay/analysis/extract_fused_functions.cc index e76b54e2d0b7..13c189ca5a3a 100644 --- a/src/relay/analysis/extract_fused_functions.cc +++ b/src/relay/analysis/extract_fused_functions.cc @@ -52,15 +52,14 @@ class FusedFunctionExtractorWrapper : private ExprVisitor { // have the desired equals property Map functions; - void VisitExpr_(const FunctionNode* n) final { - if (n->HasNonzeroAttr(attr::kPrimitive)) { + void VisitExpr_(const FunctionNode* func_node) final { + if (func_node->HasNonzeroAttr(attr::kPrimitive)) { // Add function to functions, keyed by function hash string - Function func = Function(n->params, n->body, n->ret_type, n->type_params, n->attrs); - size_t hash_ = tvm::StructuralHash()(func); - this->functions.Set(std::to_string(hash_), func); + size_t hash_ = tvm::StructuralHash()(GetRef(func_node)); + this->functions.Set(std::to_string(hash_), GetRef(func_node)); } - ExprVisitor::VisitExpr_(n); + ExprVisitor::VisitExpr_(func_node); } }; diff --git a/src/relay/backend/contrib/cmsisnn/extract_constants.cc b/src/relay/backend/contrib/cmsisnn/extract_constants.cc index 5ed23ad1ad6a..a3c6f8367a02 100644 --- a/src/relay/backend/contrib/cmsisnn/extract_constants.cc +++ b/src/relay/backend/contrib/cmsisnn/extract_constants.cc @@ -46,15 +46,14 @@ class ExtractConstantsMutator : public MixedModeMutator { private: String gen_var_name() { return "tvm_var_extract_const_" + std::to_string(var_count_++); } - Expr VisitExpr_(const FunctionNode* function) final { - Function func = GetRef(function); + Expr VisitExpr_(const FunctionNode* func_node) final { + Function func = GetRef(func_node); function_to_constants_.Set(func, Array{}); functions_.push_back(func); auto new_body = VisitExpr(func->body); functions_.pop_back(); if (function_to_constants_[func].size()) { - func = Function(FreeVars(new_body), new_body, func->ret_type, FreeTypeVars(new_body, mod_), - func->attrs); + func = WithFields(std::move(func), FreeVars(new_body), new_body, func->ret_type, FreeTypeVars(new_body, mod_), func->attrs); } return func; } @@ -145,8 +144,7 @@ IRModule ExtractConstants(const IRModule& mod) { auto new_main_body = extract_constants.VisitExpr(main_func->body); if (!new_main_body.same_as(main_func->body)) { auto main_var = mod->GetGlobalVar("main"); - auto new_main_func = Function(main_func->params, new_main_body, main_func->ret_type, - main_func->type_params, main_func->attrs); + Function new_main_func = WithFields(std::move(main_func), main_func->params, new_main_body); mod->Update(main_var, new_main_func); } return mod; diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index 668352700805..2197d6afd1ca 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -44,13 +44,8 @@ class RelayToTIRVisitor : public MixedModeMutator { IRModule Mutate() { GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); - BaseFunc main = ir_module_->Lookup(main_global_var); - Function main_func = GetRef(main.as()); - - // Copy everything across and mutate the body - Function mutated_main = - Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type, - main_func->type_params, main_func->attrs, main_func->span); + Function main = Downcast(ir_module_->Lookup(main_global_var)); + Function mutated_main = WithFields(std::move(main), main->params, VisitExpr(main->body)); ir_module_->Update(main_global_var, mutated_main); diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc index d618a4971189..a26ae5edd4d7 100644 --- a/src/relay/backend/contrib/ethosu/codegen.cc +++ b/src/relay/backend/contrib/ethosu/codegen.cc @@ -56,12 +56,8 @@ class RelayToTIRMutator : public MixedModeMutator { IRModule operator()() { GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); - Function main_func = Downcast(ir_module_->Lookup(main_global_var)); - - // Copy everything across and mutate the body - Function mutated_main = - Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type, - main_func->type_params, main_func->attrs, main_func->span); + Function main = Downcast(ir_module_->Lookup(main_global_var)); + Function mutated_main = WithFields(std::move(main), main->params, VisitExpr(main->body)); ir_module_->Update(main_global_var, mutated_main); ir_module_ = WithAttr(ir_module_, "device_contexts", device_contexts_); diff --git a/src/relay/backend/contrib/ethosu/preprocess.cc b/src/relay/backend/contrib/ethosu/preprocess.cc index 571a56ad97c0..446524aed964 100644 --- a/src/relay/backend/contrib/ethosu/preprocess.cc +++ b/src/relay/backend/contrib/ethosu/preprocess.cc @@ -177,11 +177,10 @@ class ExternalFuncIOHandler : public ExprRewriter { reshaped_outputs.push_back(CreateFlattenTensor(out)); } auto concat_out = CreateConcatTensor(reshaped_outputs); - auto f = Function(params, concat_out, concat_out->checked_type_, {}, func->attrs); + Function f = WithFields(std::move(func), std::move(params), std::move(concat_out), std::move(concat_out->checked_type_), Array() /* erase type params */); return InferType(f, this->module_); } else { - auto f = - Function(params, core_compute_expr, core_compute_expr->checked_type_, {}, func->attrs); + Function f = WithFields(std::move(func), std::move(params), std::move(core_compute_expr), std::move(core_compute_expr->checked_type_), Array() /* erase type params */); return InferType(f, this->module_); } } diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index c41399e314ef..c544ef683bb1 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -43,13 +43,8 @@ class ConvertAddToSubtract : public MixedModeMutator { IRModule Mutate() { GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); - BaseFunc main = ir_module_->Lookup(main_global_var); - Function main_func = GetRef(main.as()); - - // Copy everything across and mutate the body - Function mutated_main = - Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type, - main_func->type_params, main_func->attrs, main_func->span); + Function main = GetRef(ir_module_->Lookup(main_global_var).as()); + Function mutated_main = WithFields(std::move(main), main->params, VisitExpr(main->body)); ir_module_->Update(main_global_var, mutated_main); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index b47bc401b37f..2484f1eda8cd 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -164,8 +164,7 @@ class TECompilerImpl : public TECompilerNode { for (const auto& kv2 : kv1.second->cached_func->funcs->functions) { if (const auto* function_node = kv2.second.as()) { // Abandon the existing function annotations. - Function function(function_node->params, function_node->body, function_node->ret_type, - function_node->type_params, /*attrs=*/{}, function_node->span); + Function function = WithFields(GetRef(function_node), {}, {}, {}, {}, /* erase attributes */ DictAttrs()); // Mark function as 'extern' using the "ExternalSymbol" attribute. function = WithAttr(std::move(function), attr::kExternalSymbol, kv2.first->name_hint); module->Add(kv2.first, function); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b7c0999ecc72..d14b5cd1c949 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -251,24 +251,24 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { // fn(closure args, lifter function args) { body } // Do that flattening on-the-fly here. Function inner_func = Downcast(func->body); - std::vector params; + Array params; std::vector param_se_scopes; params.reserve(func->params.size() + inner_func->params.size()); param_se_scopes.reserve(func->params.size() + inner_func->params.size()); param_device_indexes.reserve(func->params.size() + inner_func->params.size()); for (size_t i = 0; i < func->params.size(); ++i) { - params.emplace_back(func->params[i]); + params.push_back(func->params[i]); SEScope param_se_scope = GetFunctionParamSEScope(func.get(), i); param_se_scopes.push_back(param_se_scope); param_device_indexes.push_back(GetDeviceIndex(param_se_scope)); } for (size_t i = 0; i < inner_func->params.size(); ++i) { - params.emplace_back(inner_func->params[i]); + params.push_back(inner_func->params[i]); SEScope param_se_scope = GetFunctionParamSEScope(inner_func.get(), i); param_se_scopes.push_back(param_se_scope); param_device_indexes.push_back(GetDeviceIndex(param_se_scope)); } - std::vector type_params; + Array type_params; type_params.reserve(func->type_params.size() + inner_func->type_params.size()); for (const auto& tyvar : func->type_params) { type_params.push_back(tyvar); @@ -276,8 +276,8 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { for (const auto& tyvar : inner_func->type_params) { type_params.push_back(tyvar); } - Function flattened_func = Function(params, inner_func->body, inner_func->ret_type, - type_params, func->attrs, func->span); + Function flattened_func = WithFields(std::move(func), std::move(params), inner_func->body, inner_func->ret_type, + std::move(type_params)); VisitExpr(MaybeFunctionOnDevice(flattened_func, param_se_scopes, GetFunctionResultSEScope(inner_func.get()))); } else { diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index ffd0e466eb24..ea297b2bf8b9 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -102,8 +102,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { if (function_nesting() == 1) { // We don't need to lift global functions. - return Function(func_node->params, VisitExpr(func_node->body), func_node->ret_type, - func_node->type_params, func_node->attrs, func_node->span); + return WithFields(GetRef(func_node), func_node->params, VisitExpr(func_node->body)); } auto name = GenerateName(func); @@ -172,8 +171,13 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { // code for the closure. Function lifted_func; if (captured_vars.empty() && free_type_vars.empty()) { + // Intentionally creating a fresh copy, since the lifted function is being bound in the global + // environment, not the local environment. + // TODO(@electriclilies): need to propagate virtual_device + // COME BACk lifted_func = Function(body->params, body->body, body->ret_type, body->type_params, body->attrs, body->span); + lifted_func->virtual_device_ = body->virtual_device(); } else { // When a closure is locally bound in a program, we have its full type information // avalible to us. @@ -188,13 +192,14 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { // construct the "closure" function with fully annotated arguments, no longer relying // on type inference. size_t before_arity = body->params.size(); - auto rebound_body = Function(func->params, Bind(body->body, rebinding_map), func->ret_type, - func->type_params, func->attrs, func->span); + auto rebound_body = WithFields(func, func->params, Bind(body->body, rebinding_map)); size_t after_arity = rebound_body->params.size(); CHECK_EQ(before_arity, after_arity); + // COME BACK lifted_func = Function(typed_captured_vars, rebound_body, /*ret_type=*/func->func_type_annotation(), free_type_vars, /*attrs=*/{}, func->span); + // COME BACK lifted_func = MaybeFunctionOnDevice(lifted_func, captured_var_se_scopes, result_se_scope); lifted_func = MarkClosure(lifted_func); } diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index a08de39d0abb..4b00bfec90e7 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -477,42 +477,28 @@ class ExprBinder : public MixedModeMutator, PatternMutator { Expr Bind(const Expr& expr, const tvm::Map& args_map) { if (const FunctionNode* func = expr.as()) { Expr new_body = ExprBinder(args_map).VisitExpr(func->body); - Array new_params; - std::vector new_param_se_scopes; - for (size_t i = 0; i < func->params.size(); ++i) { - if (!args_map.count(func->params[i])) { - new_params.push_back(func->params[i]); - new_param_se_scopes.push_back(GetFunctionParamSEScope(func, i)); - } - } - if (new_body.same_as(func->body) && new_params.size() == func->params.size()) { - return expr; - } - auto ret = - Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span); - ret = MaybeFunctionOnDevice(ret, new_param_se_scopes, GetFunctionResultSEScope(func)); + + Function ret_func = WithFields(GetRef(func), func->params, std::move(new_body)); + std::unordered_set set; for (const auto& v : FreeVars(expr)) { set.insert(v); } - for (const auto& v : FreeVars(ret)) { + + for (const auto& v : FreeVars(ret_func)) { if (set.count(v) == 0) { - new_params.push_back(v); - if (!GetFunctionResultSEScope(func)->IsFullyUnconstrained()) { + if (!v->virtual_device()->IsFullyConstrained()) { // TODO(mbs): The function has been annotated with a device, which means we are supposed // to be preserving device annotations on every transformation. However there's no // such context for the free vars in args_map. LOG(WARNING) << "introduced free var '" << PrettyPrint(v) << "' into function body but no device is known for it"; + continue; } - new_param_se_scopes.push_back(SEScope::FullyUnconstrained()); } } - ret = - Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span); - ret = MaybeFunctionOnDevice(ret, new_param_se_scopes, GetFunctionResultSEScope(func)); - ICHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size()); - return std::move(ret); + ICHECK_EQ(FreeVars(expr).size(), FreeVars(ret_func).size()); + return std::move(ret_func); } else { return ExprBinder(args_map).VisitExpr(expr); } diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc index cdd82d164b4f..56e411bf78b1 100644 --- a/src/relay/op/memory/on_device.cc +++ b/src/relay/op/memory/on_device.cc @@ -133,12 +133,12 @@ Function FunctionOnDevice(Function function, Array param_se_scopes, Var param = function->params[i]; new_params.push_back(WithFields(std::move(param), {}, {}, std::move(param_se_scopes[i]))); } - - return WithAttrs(WithFields(std::move(function), std::move(new_params)), {{tvm::attr::kResultSEScope, std::move(result_se_scope)}}); + return WithFields(std::move(function), std::move(new_params), {}, {}, {}, {}, std::move(result_se_scope)); } TVM_REGISTER_GLOBAL("relay.op.annotation._make.FunctionOnDevice").set_body_typed(FunctionOnDevice); +/* Function MaybeFunctionOnDevice(Function function, Array param_se_scopes, SEScope result_se_scope) { if (std::all_of(param_se_scopes.begin(), param_se_scopes.end(), @@ -148,11 +148,10 @@ Function MaybeFunctionOnDevice(Function function, Array param_se_scopes return function; } return FunctionOnDevice(function, std::move(param_se_scopes), std::move(result_se_scope)); -} +}*/ SEScope GetFunctionResultSEScope(const FunctionNode* function_node) { - auto opt_se_scope = function_node->GetAttr(tvm::attr::kResultSEScope); - return opt_se_scope.value_or(SEScope::FullyUnconstrained()); + return function_node->virtual_device(); } SEScope GetFunctionParamSEScope(const FunctionNode* function_node, size_t i) { @@ -160,13 +159,7 @@ SEScope GetFunctionParamSEScope(const FunctionNode* function_node, size_t i) { << "param index " << i << " out of range for function of arity " << function_node->params.size(); - // TODO(@electriclilies): Should we still check that all param sescopes are defined here? - SEScope se_scope = function_node->params[i]->virtual_device(); - if (se_scope.defined()) { - return SEScope::FullyUnconstrained(); - } - - return se_scope; + return function_node->params[i]->virtual_device(); } } // namespace relay diff --git a/src/relay/quantize/annotate.cc b/src/relay/quantize/annotate.cc index 3def616e9423..c8e378d544ba 100644 --- a/src/relay/quantize/annotate.cc +++ b/src/relay/quantize/annotate.cc @@ -98,7 +98,7 @@ Pass QuantizeAnnotate() { for (const auto& x : FreeVars(func)) { new_params.push_back(x); } - return Function(new_params, func->body, func->ret_type, func->type_params, func->attrs); + return WithFields(std::move(func), std::move(new_params)); }; return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {}); } diff --git a/src/relay/quantize/calibrate.cc b/src/relay/quantize/calibrate.cc index 0ac445295496..5bc64bf7368a 100644 --- a/src/relay/quantize/calibrate.cc +++ b/src/relay/quantize/calibrate.cc @@ -152,8 +152,7 @@ class StatsCollector : private ExprMutator { const FunctionNode* func = new_e.as(); ICHECK(func) << "Input shoule be Function"; Expr new_body = Tuple(std::move(profile_data_)); - return Function(FreeVars(new_body), new_body, NullValue(), func->type_params, - func->attrs); + return WithFields(GetRef(func), FreeVars(new_body), std::move(new_body), NullValue()); } private: diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index df1a858f8d0b..9aa63b263a21 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -295,7 +295,7 @@ class AnnotateTargetRewriter : public ExprRewriter { func = Downcast(post); new_body = InsertCompilerEndAndPropogateTarget(func->body); } - return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs); + return WithFields(std::move(func), func->params, std::move(new_body)); } Expr Rewrite_(const LetNode* op, const Expr& post) override { diff --git a/src/relay/transforms/convert_sparse_conv2d.cc b/src/relay/transforms/convert_sparse_conv2d.cc index 3f2c25e988f9..dc71961b2172 100644 --- a/src/relay/transforms/convert_sparse_conv2d.cc +++ b/src/relay/transforms/convert_sparse_conv2d.cc @@ -292,12 +292,12 @@ Pass Conv2dToSparse(const Array& weight_name, const Array(Conv2dToSparse(f, weight_name, weight_shape, layout, kernel_size)); Array sparse_params = FreeVars(f0); - auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); + auto f1 = WithFields(std::move(f0), std::move(sparse_params)); Array params = FreeVars(f1); for (const auto& var : sparse_params) { params.push_back(var); } - return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs); + return WithFields(std::move(f1), std::move(params)); }; return CreateFunctionPass(pass_func, 4, "Conv2dToSparse", {"DeadCodeElimination"}); } diff --git a/src/relay/transforms/convert_sparse_dense.cc b/src/relay/transforms/convert_sparse_dense.cc index 26a4d487196d..4bd456a74a3f 100644 --- a/src/relay/transforms/convert_sparse_dense.cc +++ b/src/relay/transforms/convert_sparse_dense.cc @@ -135,12 +135,12 @@ Pass DenseToSparse(const Array& weight_name, // Remove FreeVar warnings auto f0 = Downcast(DenseToSparse(f, weight_name, weight_shape)); Array sparse_params = FreeVars(f0); - auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); + auto f1 = WithFields(std::move(f0), std::move(sparse_params)); Array params = FreeVars(f1); for (const auto& var : sparse_params) { params.push_back(var); } - return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs); + return WithFields(std::move(f1), std::move(params)); }; return CreateFunctionPass(pass_func, 4, "DenseToSparse", {"DeadCodeElimination"}); } diff --git a/src/relay/transforms/de_duplicate.cc b/src/relay/transforms/de_duplicate.cc index 2fd88736bf31..d9486ce9bf95 100644 --- a/src/relay/transforms/de_duplicate.cc +++ b/src/relay/transforms/de_duplicate.cc @@ -81,16 +81,16 @@ Expr DeDup(const Expr& e) { Type VisitType(const Type& t) final { return t.defined() ? TypeMutator::VisitType(t) : t; } - Expr VisitExpr_(const FunctionNode* op) final { + Expr VisitExpr_(const FunctionNode* func_node) final { tvm::Array type_params; - for (const TypeVar& type_param : op->type_params) { + for (const TypeVar& type_param : func_node->type_params) { type_params.push_back(Fresh(type_param)); } tvm::Array params; - for (const Var& param : op->params) { + for (const Var& param : func_node->params) { params.push_back(Fresh(param)); } - return Function(params, VisitExpr(op->body), VisitType(op->ret_type), type_params, op->attrs); + return WithFields(GetRef(func_node), std::move(params), VisitExpr(func_node->body), VisitType(func_node->ret_type), std::move(type_params)); } Pattern VisitPattern(const Pattern& p) final { return PatternFunctor::VisitPattern(p); } diff --git a/src/relay/transforms/defunctionalization.cc b/src/relay/transforms/defunctionalization.cc index 5255a672a856..c3cd994e44dd 100644 --- a/src/relay/transforms/defunctionalization.cc +++ b/src/relay/transforms/defunctionalization.cc @@ -283,7 +283,7 @@ class DefuncMutator : public ExprMutator { auto apply_gv = GetApplyFunction(ft); auto body = this->VisitExpr(Bind(fn->body, free_var_bind_map)); - AddApplyCase(apply_gv, ft, c, Function(fn->params, body, fn->ret_type, fn->type_params), + AddApplyCase(apply_gv, ft, c, WithFields(GetRef(fn), fn->params, std::move(body)), pattern_vars); return Call(c, call_args); @@ -380,7 +380,7 @@ class DefuncMutator : public ExprMutator { map.Set(f->type_params[i], type_args[i]); } // copy with typevars removed - auto copy = TypeSubst(Function(f->params, f->body, f->ret_type, {}), map); + auto copy = TypeSubst(WithFields(std::move(f), {}, {}, {}, /* erase typeparams */ Array()), map); return Downcast(copy); } @@ -410,7 +410,7 @@ class DefuncMutator : public ExprMutator { } auto bind = Downcast(Bind(f, var_bind_map)); - return Function(params, this->VisitExpr(bind->body), bind->ret_type, {}); + return WithFields(std::move(bind), std::move(params), this->VisitExpr(bind->body)); } }; diff --git a/src/relay/transforms/eta_expand.cc b/src/relay/transforms/eta_expand.cc index 4023c9dafef4..9bdd04231a5f 100644 --- a/src/relay/transforms/eta_expand.cc +++ b/src/relay/transforms/eta_expand.cc @@ -129,8 +129,7 @@ class EtaExpander : public ExprMutator { params.push_back(var); args.push_back(var); } - - return Function(args, Call(gvar, params), func->ret_type, func->type_params); + return WithFields(std::move(func), std::move(args), Call(gvar, params)); } else { return std::move(gvar); } diff --git a/src/relay/transforms/first_order_gradient.cc b/src/relay/transforms/first_order_gradient.cc index 9408d16d87e9..9c1d4baa0617 100644 --- a/src/relay/transforms/first_order_gradient.cc +++ b/src/relay/transforms/first_order_gradient.cc @@ -307,8 +307,8 @@ Pass FirstOrderGradient() { }); return Pair(res.forward, grad_tuple); }); - ad_mod->Update(pr.first, - Function(func->params, body, GradRetType(GetRef(func)), {})); + ad_mod->Update(pr.first, + WithFields(GetRef(func), func->params, std::move(body), GradRetType(GetRef(func)))); } return ad_mod; diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index f2fc0af4f9c1..a64e6282266d 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -951,6 +951,8 @@ class FuseMutator : private MixedModeMutator { Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) { // Quickly check special properties of the fused function. // A pass to check if the fused op contains only reshape ops. + // TODO(@electriclilies): This won't propagate the rest of the function information correctly. + // We probably should pass the whole function in here. class CheckReshapeOnly : public ExprVisitor { public: void VisitExpr_(const CallNode* cn) final { diff --git a/src/relay/transforms/higher_order_gradient.cc b/src/relay/transforms/higher_order_gradient.cc index 202275626d5d..68a34c28c967 100644 --- a/src/relay/transforms/higher_order_gradient.cc +++ b/src/relay/transforms/higher_order_gradient.cc @@ -341,28 +341,27 @@ struct ReverseAD : ExprMutator { GlobalVar gv(op->name_hint + "_grad"); (*ad_gvars)[orig_gv] = gv; Function orig_f = Downcast(DeDup(mod.value()->Lookup(orig_gv))); - std::vector params; + Array params; for (const auto& p : orig_f->params) { params.push_back(Downcast(VisitExpr(p))); } params.push_back(bp); - Expr body = VisitExpr(orig_f->body); - Function f(params, body, VisitType(orig_f->ret_type), orig_f->type_params, orig_f->attrs); + Function f = WithFields(std::move(orig_f), std::move(params), VisitExpr(orig_f->body), VisitType(orig_f->ret_type)); std::cout << "gv " << op->name_hint << ": " << AsText(f, false) << std::endl; mod.value()->Add(gv, f); } return ad_gvars->at(orig_gv); } - Expr VisitExpr_(const FunctionNode* op) final { - std::vector params; - for (const auto& var : op->params) { + Expr VisitExpr_(const FunctionNode* func_node) final { + Array params; + for (const auto& var : func_node->params) { params.push_back(Downcast(VisitExpr(var))); } auto new_bp = Var("bp", bpt); params.push_back(new_bp); - return Function(params, ReverseAD(mod, new_bp, ad_vars, ad_gvars)(op->body), - VisitType(op->ret_type), op->type_params, op->attrs); + return WithFields(GetRef(func_node), std::move(params), ReverseAD(mod, new_bp, ad_vars, ad_gvars)(func_node->body), + VisitType(func_node->ret_type)); } Type VisitType(const Type& t) final { return t.defined() ? ReverseType(t) : t; } @@ -456,7 +455,7 @@ Expr Gradient(const Expr& re, const Optional& mod) { }; return Pair(get_final_result(c, f->body->checked_type()), Tuple(ret)); }); - auto ret = Function(f->params, body, GradRetType(GetRef(f)), {}); + Function ret = WithFields(GetRef(f), f->params, std::move(body), GradRetType(GetRef(f))); CheckFeature(ret, FeatureSet::All() - fGraph); return std::move(ret); } diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc index f1492b9f1258..9ddfb53dd3c9 100644 --- a/src/relay/transforms/inline.cc +++ b/src/relay/transforms/inline.cc @@ -91,8 +91,7 @@ class Inliner : ExprMutator { } Function Inline(const Function& func) { - return Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, - func->attrs); + return WithFields(std::move(func), func->params, VisitExpr(func->body)); } private: @@ -131,7 +130,9 @@ class Inliner : ExprMutator { const auto* fn = base_func.as(); ICHECK(fn) << "Expected to work on a Relay function."; - auto func = Function(fn->params, fn->body, fn->ret_type, fn->type_params, fn->attrs); + // TODO(@electriclilies): If Function is a COW node, then if it gets written to we shouldn't have any sharing, right? + // So we don't need to reconstruct? + Function func = WithFields(GetRef(fn)); // Inline the function body to the caller if this function uses default // compiler, i.e. no external codegen is needed. if (!func->GetAttr(attr::kCompiler).defined() && diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index 8f5e9e146d54..bb2a6fc97b46 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -827,7 +827,7 @@ class PartialEvaluator : public ExprFunctor Expr VisitFuncDynamic(const Function& func, const Func& f, const Expr& self) { return store_.Extend([&]() { store_.Invalidate(); - return Function(func->params, LetList::With([&](LetList* ll) { + return WithFields(std::move(func), func->params, LetList::With([&](LetList* ll) { std::vector pv; for (const auto& v : func->params) { pv.push_back(NoStatic(v)); @@ -837,8 +837,7 @@ class PartialEvaluator : public ExprFunctor type_args.push_back(tp); } return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic; - }), - func->ret_type, func->type_params, func->attrs); + })); }); } diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index 4a21bc87411b..d82597adb6d8 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -212,9 +212,7 @@ class Partitioner : public MixedModeMutator { auto glob_funcs = module_->functions; for (const auto& pair : glob_funcs) { if (auto* fn = pair.second.as()) { - auto func = GetRef(fn); - func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, - func->attrs); + Function func = WithFields(GetRef(fn), func->params, VisitExpr(fn->body)); module_->Update(pair.first, func); module_ = transform::InferType()(module_); } @@ -425,8 +423,8 @@ IRModule RemoveDefaultAnnotations(IRModule module) { auto func = GetRef(fn); DefaultRemover remover; auto removed = PostOrderRewrite(func->body, &remover); - func = Function(func->params, removed, func->ret_type, func->type_params, func->attrs); - module->Update(pair.first, func); + func = WithFields(GetRef(fn), func->params, std::move(removed)); + module->Update(pair.first, GetRef(fn)); module = relay::transform::InferType()(module); } } @@ -478,10 +476,10 @@ IRModule FlattenTupleOutputs(IRModule module) { module.CopyOnWrite(); for (const auto& pair : glob_funcs) { if (auto* fn = pair.second.as()) { - auto func = GetRef(fn); + Function func = GetRef(fn); TupleOutFlattener to_flattener; auto removed = PostOrderRewrite(func->body, &to_flattener); - func = Function(func->params, removed, func->ret_type, func->type_params, func->attrs); + func = WithFields(GetRef(fn), func->params, std::move(removed)); module->Update(pair.first, func); module = relay::transform::InferType()(module); } @@ -523,12 +521,10 @@ class NameMangleExtFuncs : public MixedModeMutator { auto new_dict = func->attrs->dict; new_dict.Set(tvm::attr::kGlobalSymbol, String(relay::backend::SanitizeName(mangle_fn_(pair.first->name_hint)))); - func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, - DictAttrs(new_dict)); + func = WithFields(func, func->params, VisitExpr(func->body), func->ret_type, func->type_params, DictAttrs(new_dict)); new_module->Add(mangled_gvars_[pair.first->name_hint], func); } else { - func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, - func->attrs); + func = WithFields(func, func->params, VisitExpr(func->body)); new_module->Add(pair.first, func); } } diff --git a/src/relay/transforms/pass_utils.h b/src/relay/transforms/pass_utils.h index 317ac17f83c8..b14a93f02b55 100644 --- a/src/relay/transforms/pass_utils.h +++ b/src/relay/transforms/pass_utils.h @@ -106,7 +106,7 @@ bool IsDataDependent(const CallNode* call); */ inline Expr TransformF(const std::function& func, const Expr& e) { if (const FunctionNode* f = e.as()) { - return Function(f->params, func(f->body), f->ret_type, f->type_params, f->attrs); + return WithFields(GetRef(f), f->params, func(f->body)); } else { return func(e); } diff --git a/src/relay/transforms/simplify_fc_transpose.cc b/src/relay/transforms/simplify_fc_transpose.cc index b5090e7e6fe4..d5ceb2986163 100644 --- a/src/relay/transforms/simplify_fc_transpose.cc +++ b/src/relay/transforms/simplify_fc_transpose.cc @@ -128,12 +128,12 @@ Pass SimplifyFCTranspose(const Array& target_weights) { // Remove FreeVar warning auto f0 = Downcast(SimplifyFCTranspose(f, target_weights)); Array wt_params = FreeVars(f0); - auto f1 = Function(wt_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); + auto f1 = WithFields(std::move(f0), wt_params); Array params = FreeVars(f1); for (const auto& var : wt_params) { params.push_back(var); } - return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs); + return WithFields(std::move(f1), params); }; return CreateFunctionPass(pass_func, 4, "SimplifyFCTranspose", {"DeadCodeElimination"}); } diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index f958a600551e..04c7418f35ff 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -298,8 +298,7 @@ class Fill : ExprFunctor, private transform::Lexi PushBoundVar(f->params[i], GetFunctionParamSEScope(f, i)); } EnterFunctionBody(); - ret = Function(f->params, GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body)), f->ret_type, - f->type_params, f->attrs); + ret = WithFields(GetRef(f), f->params, GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body))); // We are done with this function. ExitFunctionBody(); for (size_t i = 0; i < f->params.size(); ++i) { diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc index 0f889cd6ff7f..ad58123cadbe 100644 --- a/src/relay/transforms/to_cps.cc +++ b/src/relay/transforms/to_cps.cc @@ -272,8 +272,8 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm, new_params.push_back(remap(v)); } new_params.push_back(k); - return Function(new_params, mut.VisitExpr(f->body, [&](const Expr& e) { return Call(k, {e}); }), - answer, f->type_params, f->attrs); + return WithFields(std::move(f), std::move(new_params), mut.VisitExpr(f->body, [&](const Expr& e) { return Call(k, {e}); }), + std::move(answer)); } Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) { @@ -299,7 +299,7 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) { Function ret = ToCPS(f, m, cm, &var, answer); auto new_type_params = ret->type_params; new_type_params.push_back(answer); - return Function(ret->params, ret->body, ret->ret_type, new_type_params, ret->attrs); + return WithFields(std::move(ret), ret->params, ret->body, ret->ret_type, std::move(new_type_params)); } Function ToCPS(const Function& f, const IRModule& m) { @@ -311,15 +311,17 @@ Function ToCPS(const Function& f, const IRModule& m) { Function UnCPS(const Function& f) { CheckFeature(f, FeatureSet::All() - fGraph); ICHECK_GT(f->params.size(), 0); - std::vector new_params; + Array new_params; for (const auto& p : f->params) { - new_params.push_back(Var(p->name_hint(), p->checked_type())); + // TODO(@electriclilies): Not sure if this is correct, it was copying before, + // but seems like we just need to make a copy to pop so should be fine? + new_params.push_back(WithFields(std::move(p))); } auto cont_type = Downcast(new_params.back()->type_annotation); new_params.pop_back(); ICHECK_EQ(cont_type->arg_types.size(), 1); auto new_ret_type = Type(cont_type->arg_types[0]); - std::vector new_type_params; + Array new_type_params; for (const auto& tp : f->type_params) { new_type_params.push_back(TypeVar(tp->name_hint, tp->kind)); } @@ -339,8 +341,9 @@ Function UnCPS(const Function& f) { type_args.push_back(tp); } type_args.push_back(new_ret_type); - return Function(new_params, Call(f, args, {}, type_args), new_ret_type, new_type_params, - f->attrs); + Call call = Call(f, args, {}, type_args); + // How do I fix this? + return WithFields(f, std::move(new_params), call, std::move(new_ret_type), std::move(new_type_params)); } TVM_REGISTER_GLOBAL("relay._transform.to_cps") From fe1dc5ca389d585256d7e1dce7f60e5833d37730 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Mon, 6 Dec 2021 15:17:53 -0800 Subject: [PATCH 04/13] Remove GetFunctionParamSEScope and GetFunctionResultSEScope helpers --- include/tvm/ir/function.h | 10 ------- src/relay/backend/vm/compiler.cc | 27 +++++++------------ src/relay/backend/vm/lambda_lift.cc | 16 +++++------ src/relay/op/memory/on_device.cc | 23 ---------------- src/relay/op/memory/on_device.h | 13 --------- src/relay/transforms/device_aware_visitors.cc | 24 ++++++++--------- src/relay/transforms/device_aware_visitors.h | 10 +++---- src/relay/transforms/device_planner.cc | 9 +++---- src/relay/transforms/to_a_normal_form.cc | 10 +++---- 9 files changed, 43 insertions(+), 99 deletions(-) diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 64b9ce964ce1..1493544e7324 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -190,16 +190,6 @@ constexpr const char* kTarget = "target"; */ constexpr const char* kGlobalSymbol = "global_symbol"; -/*! - * \brief The SEScope which will hold the function result. - * - * Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but - * may be included as an annotation on user programs. - * - * Type: SEScope - */ -constexpr const char* kResultSEScope = "result_se_scope"; - } // namespace attr } // namespace tvm #endif // TVM_IR_FUNCTION_H_ diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index d14b5cd1c949..4bae45ec703e 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -252,21 +252,15 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { // Do that flattening on-the-fly here. Function inner_func = Downcast(func->body); Array params; - std::vector param_se_scopes; params.reserve(func->params.size() + inner_func->params.size()); - param_se_scopes.reserve(func->params.size() + inner_func->params.size()); param_device_indexes.reserve(func->params.size() + inner_func->params.size()); - for (size_t i = 0; i < func->params.size(); ++i) { - params.push_back(func->params[i]); - SEScope param_se_scope = GetFunctionParamSEScope(func.get(), i); - param_se_scopes.push_back(param_se_scope); - param_device_indexes.push_back(GetDeviceIndex(param_se_scope)); + for (const Var param: func->params) { + params.push_back(param); + param_device_indexes.push_back(GetDeviceIndex(param->virtual_device())); } - for (size_t i = 0; i < inner_func->params.size(); ++i) { - params.push_back(inner_func->params[i]); - SEScope param_se_scope = GetFunctionParamSEScope(inner_func.get(), i); - param_se_scopes.push_back(param_se_scope); - param_device_indexes.push_back(GetDeviceIndex(param_se_scope)); + for (const Var inner_param : inner_func->params) { + params.push_back(inner_param); + param_device_indexes.push_back(GetDeviceIndex(inner_param->virtual_device())); } Array type_params; type_params.reserve(func->type_params.size() + inner_func->type_params.size()); @@ -277,13 +271,12 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { type_params.push_back(tyvar); } Function flattened_func = WithFields(std::move(func), std::move(params), inner_func->body, inner_func->ret_type, - std::move(type_params)); - VisitExpr(MaybeFunctionOnDevice(flattened_func, param_se_scopes, - GetFunctionResultSEScope(inner_func.get()))); + std::move(type_params), {}, inner_func->virtual_device()); + } else { param_device_indexes.reserve(func->params.size()); - for (size_t i = 0; i < func->params.size(); ++i) { - param_device_indexes.push_back(GetDeviceIndex(GetFunctionParamSEScope(func.get(), i))); + for (const Var param: func->params) { + param_device_indexes.push_back(GetDeviceIndex(param->virtual_device())); } VisitExpr(func); } diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index ea297b2bf8b9..673ef0eeeec7 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -111,27 +111,25 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { auto free_type_vars = FreeTypeVars(func, module_); Array captured_vars; - std::vector captured_var_se_scopes; bool recursive = false; for (const auto& var : free_vars) { if (!letrec_.empty() && var == letrec_.back()) { recursive = true; continue; } - captured_vars.push_back(var); - captured_var_se_scopes.push_back(GetSEScope(var)); } // Freshen all the captured vars. Array typed_captured_vars; Map rebinding_map; for (auto free_var : captured_vars) { - auto var = Var(free_var->name_hint(), free_var->checked_type()); + auto var = Var(free_var->name_hint(), free_var->checked_type(), free_var->span); + var->virtual_device_ = free_var->virtual_device(); typed_captured_vars.push_back(var); rebinding_map.Set(free_var, var); } - SEScope result_se_scope = GetSEScope(func_node->body); + SEScope result_se_scope = func_node->body->virtual_device(); if (recursive) { if (!captured_vars.empty()) { @@ -177,7 +175,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { // COME BACk lifted_func = Function(body->params, body->body, body->ret_type, body->type_params, body->attrs, body->span); - lifted_func->virtual_device_ = body->virtual_device(); + lifted_func->virtual_device_ = result_se_scope; } else { // When a closure is locally bound in a program, we have its full type information // avalible to us. @@ -195,12 +193,12 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { auto rebound_body = WithFields(func, func->params, Bind(body->body, rebinding_map)); size_t after_arity = rebound_body->params.size(); CHECK_EQ(before_arity, after_arity); - // COME BACK + lifted_func = Function(typed_captured_vars, rebound_body, /*ret_type=*/func->func_type_annotation(), free_type_vars, /*attrs=*/{}, func->span); - // COME BACK - lifted_func = MaybeFunctionOnDevice(lifted_func, captured_var_se_scopes, result_se_scope); + // We have already propagated the se_scopes of the variables + lifted_func->virtual_device_ = result_se_scope; lifted_func = MarkClosure(lifted_func); } diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc index 56e411bf78b1..aa355bc9b7fd 100644 --- a/src/relay/op/memory/on_device.cc +++ b/src/relay/op/memory/on_device.cc @@ -138,29 +138,6 @@ Function FunctionOnDevice(Function function, Array param_se_scopes, TVM_REGISTER_GLOBAL("relay.op.annotation._make.FunctionOnDevice").set_body_typed(FunctionOnDevice); -/* -Function MaybeFunctionOnDevice(Function function, Array param_se_scopes, - SEScope result_se_scope) { - if (std::all_of(param_se_scopes.begin(), param_se_scopes.end(), - [](const SEScope& se_scope) { return se_scope->IsFullyUnconstrained(); }) && - result_se_scope->IsFullyUnconstrained()) { - // Nothing to annotate. - return function; - } - return FunctionOnDevice(function, std::move(param_se_scopes), std::move(result_se_scope)); -}*/ - -SEScope GetFunctionResultSEScope(const FunctionNode* function_node) { - return function_node->virtual_device(); -} - -SEScope GetFunctionParamSEScope(const FunctionNode* function_node, size_t i) { - ICHECK_LT(i, function_node->params.size()) - << "param index " << i << " out of range for function of arity " - << function_node->params.size(); - - return function_node->params[i]->virtual_device(); -} } // namespace relay } // namespace tvm diff --git a/src/relay/op/memory/on_device.h b/src/relay/op/memory/on_device.h index a7b6cb7cf52a..41edd5f1ba7c 100644 --- a/src/relay/op/memory/on_device.h +++ b/src/relay/op/memory/on_device.h @@ -119,19 +119,6 @@ const NodeType* AsIgnoringOnDevice(const Expr& expr) { */ Function FunctionOnDevice(Function function, Array param_se_scopes, SEScope body_se_scope); -/*! - * \brief As for \p FunctionOnDevice, but returns \p function unchanged if all parameters and - * result \p SEScopes are unconstrained. - */ -Function MaybeFunctionOnDevice(Function function, Array param_se_scopes, - SEScope result_se_scope); - -/*! - * \brief Returns the \p SEScope for the resut of \p function_node, or the unconstrained - * \p SEScope if function does not have the "result_se_scope" annotation. - */ -SEScope GetFunctionResultSEScope(const FunctionNode* function_node); - /*! * \brief Returns the \p SEScope for the \p i'th parameter of \p function_node, or * the unconstrained \p SEScope if function does not have the "param_se_scopes" annotation. diff --git a/src/relay/transforms/device_aware_visitors.cc b/src/relay/transforms/device_aware_visitors.cc index e3d5a821c58e..084b4853c3b7 100644 --- a/src/relay/transforms/device_aware_visitors.cc +++ b/src/relay/transforms/device_aware_visitors.cc @@ -38,7 +38,7 @@ LexicalOnDeviceMixin::LexicalOnDeviceMixin(const Optional& maybe_mod) if (maybe_mod) { for (const auto& pair : maybe_mod.value()->functions) { if (const auto* function_node = pair.second.as()) { - SEScope se_scope = GetFunctionResultSEScope(function_node); + SEScope se_scope = function_node->virtual_device(); if (!se_scope->IsFullyUnconstrained()) { global_var_se_scopes_.emplace(pair.first, se_scope); } @@ -73,7 +73,7 @@ SEScope LexicalOnDeviceMixin::GetSEScope(const Expr& expr) const { } // else: fallthrough to unconstrained } else { - return GetFunctionResultSEScope(function_node); + return function_node->virtual_device(); } } else { if (!expr_se_scopes_.empty()) { @@ -131,11 +131,11 @@ void DeviceAwareExprVisitor::VisitExpr_(const FunctionNode* function_node) { DeviceAwareVisitExpr_(function_node); } else { // Function parameters come into scope. - for (size_t i = 0; i < function_node->params.size(); ++i) { - PushBoundVar(function_node->params[i], GetFunctionParamSEScope(function_node, i)); + for (const Var param: function_node->params) { + PushBoundVar(param, param->virtual_device()); } // Entering scope of function body. - PushSEScope(GetFunctionResultSEScope(function_node)); + PushSEScope(function_node->virtual_device()); EnterFunctionBody(); DeviceAwareVisitExpr_(function_node); @@ -144,8 +144,8 @@ void DeviceAwareExprVisitor::VisitExpr_(const FunctionNode* function_node) { ExitFunctionBody(); PopSEScope(); // Function parameters go out of scope. - for (size_t i = 0; i < function_node->params.size(); ++i) { - PopBoundVar(function_node->params[i]); + for (const Var param: function_node->params) { + PopBoundVar(param); } } } @@ -217,11 +217,11 @@ Expr DeviceAwareExprMutator::VisitExpr_(const FunctionNode* function_node) { return DeviceAwareVisitExpr_(function_node); } else { // Function parameters come into scope. - for (size_t i = 0; i < function_node->params.size(); ++i) { - PushBoundVar(function_node->params[i], GetFunctionParamSEScope(function_node, i)); + for (const Var param: function_node->params) { + PushBoundVar(param, param->virtual_device()); } // Entering scope of function body. - PushSEScope(GetFunctionResultSEScope(function_node)); + PushSEScope(function_node->virtual_device()); EnterFunctionBody(); Expr result = DeviceAwareVisitExpr_(function_node); @@ -230,8 +230,8 @@ Expr DeviceAwareExprMutator::VisitExpr_(const FunctionNode* function_node) { ExitFunctionBody(); PopSEScope(); // Function parameters go out of scope. - for (size_t i = 0; i < function_node->params.size(); ++i) { - PopBoundVar(function_node->params[i]); + for (const Var param: function_node->params) { + PopBoundVar(param); } return result; diff --git a/src/relay/transforms/device_aware_visitors.h b/src/relay/transforms/device_aware_visitors.h index 8cdf0db74ebd..da9196bf8629 100644 --- a/src/relay/transforms/device_aware_visitors.h +++ b/src/relay/transforms/device_aware_visitors.h @@ -142,11 +142,11 @@ class DeviceAwareExprFunctor : public ExprFunctorparams.size(); ++i) { - PushBoundVar(function_node->params[i], GetFunctionParamSEScope(function_node, i)); + for (const Var param: function_node->params) { + PushBoundVar(param, param->virtual_device()); } // Entering scope of function body. - PushSEScope(GetFunctionResultSEScope(function_node)); + PushSEScope(function_node->virtual_device()); EnterFunctionBody(); DeviceAwareVisitExpr_(function_node); @@ -155,8 +155,8 @@ class DeviceAwareExprFunctor : public ExprFunctorparams.size(); ++i) { - PopBoundVar(function_node->params[i]); + for (const Var param: function_node->params) { + PopBoundVar(param); } } } diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index 8ea5f5dac0a4..3e63b0347716 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -479,14 +479,13 @@ class DeviceAnalyzer : public ExprVisitor { // If the function already has SEScope attributes then we can further constrain the // function's domain to match them. - if (!GetFunctionResultSEScope(function_node)->IsFullyUnconstrained()) { + if (!function_node->virtual_device()->IsFullyUnconstrained()) { std::vector args_and_result; - for (size_t i = 0; i < function_node->params.size(); ++i) { + for (const Var param: function_node->params) { args_and_result.emplace_back(domains_->ForSEScope( - function_node->params[i]->checked_type(), GetFunctionParamSEScope(function_node, i))); + param->checked_type(), param->virtual_device())); } - args_and_result.emplace_back(domains_->ForSEScope(function_node->body->checked_type(), - GetFunctionResultSEScope(function_node))); + args_and_result.emplace_back(domains_->ForSEScope(function_node->body->checked_type(), function_node->virtual_device())); auto annotation_domain = domains_->MakeHigherOrderDomain(std::move(args_and_result)); if (domains_->UnifyOrNull(func_domain, annotation_domain) == nullptr) { // higher-order // TODO(mbs): Proper diagnostics. diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 04c7418f35ff..86b5b7a32126 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -293,16 +293,16 @@ class Fill : ExprFunctor, private transform::Lexi } else { // Keep track of expression and bound variable device types for lexically enclosing // sub-expressions. - PushSEScope(GetFunctionResultSEScope(f)); - for (size_t i = 0; i < f->params.size(); ++i) { - PushBoundVar(f->params[i], GetFunctionParamSEScope(f, i)); + PushSEScope(f->virtual_device()); + for (const Var param: f->params) { + PushBoundVar(param, param->virtual_device()); } EnterFunctionBody(); ret = WithFields(GetRef(f), f->params, GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body))); // We are done with this function. ExitFunctionBody(); - for (size_t i = 0; i < f->params.size(); ++i) { - PopBoundVar(f->params[i]); + for (const Var param: f->params) { + PopBoundVar(param); } PopSEScope(); } From 559e873057d9e19c187938508e71b12a416b2020 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Mon, 6 Dec 2021 15:52:10 -0800 Subject: [PATCH 05/13] Lint --- .../backend/contrib/cmsisnn/extract_constants.cc | 3 ++- src/relay/backend/contrib/ethosu/preprocess.cc | 8 ++++++-- src/relay/backend/te_compiler.cc | 3 ++- src/relay/backend/vm/compiler.cc | 9 +++++---- src/relay/backend/vm/lambda_lift.cc | 2 +- src/relay/op/memory/on_device.cc | 7 ++++--- src/relay/quantize/calibrate.cc | 3 ++- src/relay/transforms/de_duplicate.cc | 3 ++- src/relay/transforms/defunctionalization.cc | 3 ++- src/relay/transforms/device_aware_visitors.cc | 8 ++++---- src/relay/transforms/device_aware_visitors.h | 4 ++-- src/relay/transforms/device_planner.cc | 9 +++++---- src/relay/transforms/first_order_gradient.cc | 4 ++-- src/relay/transforms/higher_order_gradient.cc | 11 +++++++---- src/relay/transforms/inline.cc | 4 ++-- src/relay/transforms/partition_graph.cc | 3 ++- src/relay/transforms/to_a_normal_form.cc | 7 ++++--- src/relay/transforms/to_cps.cc | 11 +++++++---- 18 files changed, 61 insertions(+), 41 deletions(-) diff --git a/src/relay/backend/contrib/cmsisnn/extract_constants.cc b/src/relay/backend/contrib/cmsisnn/extract_constants.cc index a3c6f8367a02..f31c3b486e85 100644 --- a/src/relay/backend/contrib/cmsisnn/extract_constants.cc +++ b/src/relay/backend/contrib/cmsisnn/extract_constants.cc @@ -53,7 +53,8 @@ class ExtractConstantsMutator : public MixedModeMutator { auto new_body = VisitExpr(func->body); functions_.pop_back(); if (function_to_constants_[func].size()) { - func = WithFields(std::move(func), FreeVars(new_body), new_body, func->ret_type, FreeTypeVars(new_body, mod_), func->attrs); + func = WithFields(std::move(func), FreeVars(new_body), new_body, func->ret_type, + FreeTypeVars(new_body, mod_), func->attrs); } return func; } diff --git a/src/relay/backend/contrib/ethosu/preprocess.cc b/src/relay/backend/contrib/ethosu/preprocess.cc index 446524aed964..882f9c105884 100644 --- a/src/relay/backend/contrib/ethosu/preprocess.cc +++ b/src/relay/backend/contrib/ethosu/preprocess.cc @@ -177,10 +177,14 @@ class ExternalFuncIOHandler : public ExprRewriter { reshaped_outputs.push_back(CreateFlattenTensor(out)); } auto concat_out = CreateConcatTensor(reshaped_outputs); - Function f = WithFields(std::move(func), std::move(params), std::move(concat_out), std::move(concat_out->checked_type_), Array() /* erase type params */); + Function f = WithFields(std::move(func), std::move(params), std::move(concat_out), + std::move(concat_out->checked_type_), + Array() /* erase type params */); return InferType(f, this->module_); } else { - Function f = WithFields(std::move(func), std::move(params), std::move(core_compute_expr), std::move(core_compute_expr->checked_type_), Array() /* erase type params */); + Function f = WithFields(std::move(func), std::move(params), std::move(core_compute_expr), + std::move(core_compute_expr->checked_type_), + Array() /* erase type params */); return InferType(f, this->module_); } } diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 2484f1eda8cd..273bbae8c477 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -164,7 +164,8 @@ class TECompilerImpl : public TECompilerNode { for (const auto& kv2 : kv1.second->cached_func->funcs->functions) { if (const auto* function_node = kv2.second.as()) { // Abandon the existing function annotations. - Function function = WithFields(GetRef(function_node), {}, {}, {}, {}, /* erase attributes */ DictAttrs()); + Function function = WithFields(GetRef(function_node), {}, {}, {}, {}, + /* erase attributes */ DictAttrs()); // Mark function as 'extern' using the "ExternalSymbol" attribute. function = WithAttr(std::move(function), attr::kExternalSymbol, kv2.first->name_hint); module->Add(kv2.first, function); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 4bae45ec703e..d8d415f631be 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -254,7 +254,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { Array params; params.reserve(func->params.size() + inner_func->params.size()); param_device_indexes.reserve(func->params.size() + inner_func->params.size()); - for (const Var param: func->params) { + for (const Var param : func->params) { params.push_back(param); param_device_indexes.push_back(GetDeviceIndex(param->virtual_device())); } @@ -270,12 +270,13 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { for (const auto& tyvar : inner_func->type_params) { type_params.push_back(tyvar); } - Function flattened_func = WithFields(std::move(func), std::move(params), inner_func->body, inner_func->ret_type, - std::move(type_params), {}, inner_func->virtual_device()); + Function flattened_func = + WithFields(std::move(func), std::move(params), inner_func->body, inner_func->ret_type, + std::move(type_params), {}, inner_func->virtual_device()); } else { param_device_indexes.reserve(func->params.size()); - for (const Var param: func->params) { + for (const Var param : func->params) { param_device_indexes.push_back(GetDeviceIndex(param->virtual_device())); } VisitExpr(func); diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 673ef0eeeec7..71623b94f29e 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -190,7 +190,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { // construct the "closure" function with fully annotated arguments, no longer relying // on type inference. size_t before_arity = body->params.size(); - auto rebound_body = WithFields(func, func->params, Bind(body->body, rebinding_map)); + auto rebound_body = WithFields(func, func->params, Bind(body->body, rebinding_map)); size_t after_arity = rebound_body->params.size(); CHECK_EQ(before_arity, after_arity); diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc index aa355bc9b7fd..5f21a1d868b6 100644 --- a/src/relay/op/memory/on_device.cc +++ b/src/relay/op/memory/on_device.cc @@ -127,17 +127,18 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) { Function FunctionOnDevice(Function function, Array param_se_scopes, SEScope result_se_scope) { - ICHECK(function->params.size() == param_se_scopes.size()) << "ParamSEScopes must be the same size as the function parameters."; + ICHECK(function->params.size() == param_se_scopes.size()) + << "ParamSEScopes must be the same size as the function parameters."; Array new_params; for (size_t i = 0; i < function->params.size(); i++) { Var param = function->params[i]; new_params.push_back(WithFields(std::move(param), {}, {}, std::move(param_se_scopes[i]))); } - return WithFields(std::move(function), std::move(new_params), {}, {}, {}, {}, std::move(result_se_scope)); + return WithFields(std::move(function), std::move(new_params), {}, {}, {}, {}, + std::move(result_se_scope)); } TVM_REGISTER_GLOBAL("relay.op.annotation._make.FunctionOnDevice").set_body_typed(FunctionOnDevice); - } // namespace relay } // namespace tvm diff --git a/src/relay/quantize/calibrate.cc b/src/relay/quantize/calibrate.cc index 5bc64bf7368a..78745c0824cb 100644 --- a/src/relay/quantize/calibrate.cc +++ b/src/relay/quantize/calibrate.cc @@ -152,7 +152,8 @@ class StatsCollector : private ExprMutator { const FunctionNode* func = new_e.as(); ICHECK(func) << "Input shoule be Function"; Expr new_body = Tuple(std::move(profile_data_)); - return WithFields(GetRef(func), FreeVars(new_body), std::move(new_body), NullValue()); + return WithFields(GetRef(func), FreeVars(new_body), std::move(new_body), + NullValue()); } private: diff --git a/src/relay/transforms/de_duplicate.cc b/src/relay/transforms/de_duplicate.cc index d9486ce9bf95..1380b2591701 100644 --- a/src/relay/transforms/de_duplicate.cc +++ b/src/relay/transforms/de_duplicate.cc @@ -90,7 +90,8 @@ Expr DeDup(const Expr& e) { for (const Var& param : func_node->params) { params.push_back(Fresh(param)); } - return WithFields(GetRef(func_node), std::move(params), VisitExpr(func_node->body), VisitType(func_node->ret_type), std::move(type_params)); + return WithFields(GetRef(func_node), std::move(params), VisitExpr(func_node->body), + VisitType(func_node->ret_type), std::move(type_params)); } Pattern VisitPattern(const Pattern& p) final { return PatternFunctor::VisitPattern(p); } diff --git a/src/relay/transforms/defunctionalization.cc b/src/relay/transforms/defunctionalization.cc index c3cd994e44dd..353d31c01ecc 100644 --- a/src/relay/transforms/defunctionalization.cc +++ b/src/relay/transforms/defunctionalization.cc @@ -380,7 +380,8 @@ class DefuncMutator : public ExprMutator { map.Set(f->type_params[i], type_args[i]); } // copy with typevars removed - auto copy = TypeSubst(WithFields(std::move(f), {}, {}, {}, /* erase typeparams */ Array()), map); + auto copy = TypeSubst( + WithFields(std::move(f), {}, {}, {}, /* erase typeparams */ Array()), map); return Downcast(copy); } diff --git a/src/relay/transforms/device_aware_visitors.cc b/src/relay/transforms/device_aware_visitors.cc index 084b4853c3b7..139fa03d2187 100644 --- a/src/relay/transforms/device_aware_visitors.cc +++ b/src/relay/transforms/device_aware_visitors.cc @@ -131,7 +131,7 @@ void DeviceAwareExprVisitor::VisitExpr_(const FunctionNode* function_node) { DeviceAwareVisitExpr_(function_node); } else { // Function parameters come into scope. - for (const Var param: function_node->params) { + for (const Var param : function_node->params) { PushBoundVar(param, param->virtual_device()); } // Entering scope of function body. @@ -144,7 +144,7 @@ void DeviceAwareExprVisitor::VisitExpr_(const FunctionNode* function_node) { ExitFunctionBody(); PopSEScope(); // Function parameters go out of scope. - for (const Var param: function_node->params) { + for (const Var param : function_node->params) { PopBoundVar(param); } } @@ -217,7 +217,7 @@ Expr DeviceAwareExprMutator::VisitExpr_(const FunctionNode* function_node) { return DeviceAwareVisitExpr_(function_node); } else { // Function parameters come into scope. - for (const Var param: function_node->params) { + for (const Var param : function_node->params) { PushBoundVar(param, param->virtual_device()); } // Entering scope of function body. @@ -230,7 +230,7 @@ Expr DeviceAwareExprMutator::VisitExpr_(const FunctionNode* function_node) { ExitFunctionBody(); PopSEScope(); // Function parameters go out of scope. - for (const Var param: function_node->params) { + for (const Var param : function_node->params) { PopBoundVar(param); } diff --git a/src/relay/transforms/device_aware_visitors.h b/src/relay/transforms/device_aware_visitors.h index da9196bf8629..7bc1f8ebcd8d 100644 --- a/src/relay/transforms/device_aware_visitors.h +++ b/src/relay/transforms/device_aware_visitors.h @@ -142,7 +142,7 @@ class DeviceAwareExprFunctor : public ExprFunctorparams) { + for (const Var param : function_node->params) { PushBoundVar(param, param->virtual_device()); } // Entering scope of function body. @@ -155,7 +155,7 @@ class DeviceAwareExprFunctor : public ExprFunctorparams) { + for (const Var param : function_node->params) { PopBoundVar(param); } } diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index 3e63b0347716..ff73edc486f4 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -481,11 +481,12 @@ class DeviceAnalyzer : public ExprVisitor { // function's domain to match them. if (!function_node->virtual_device()->IsFullyUnconstrained()) { std::vector args_and_result; - for (const Var param: function_node->params) { - args_and_result.emplace_back(domains_->ForSEScope( - param->checked_type(), param->virtual_device())); + for (const Var param : function_node->params) { + args_and_result.emplace_back( + domains_->ForSEScope(param->checked_type(), param->virtual_device())); } - args_and_result.emplace_back(domains_->ForSEScope(function_node->body->checked_type(), function_node->virtual_device())); + args_and_result.emplace_back(domains_->ForSEScope(function_node->body->checked_type(), + function_node->virtual_device())); auto annotation_domain = domains_->MakeHigherOrderDomain(std::move(args_and_result)); if (domains_->UnifyOrNull(func_domain, annotation_domain) == nullptr) { // higher-order // TODO(mbs): Proper diagnostics. diff --git a/src/relay/transforms/first_order_gradient.cc b/src/relay/transforms/first_order_gradient.cc index 9c1d4baa0617..0509469d69e6 100644 --- a/src/relay/transforms/first_order_gradient.cc +++ b/src/relay/transforms/first_order_gradient.cc @@ -307,8 +307,8 @@ Pass FirstOrderGradient() { }); return Pair(res.forward, grad_tuple); }); - ad_mod->Update(pr.first, - WithFields(GetRef(func), func->params, std::move(body), GradRetType(GetRef(func)))); + ad_mod->Update(pr.first, WithFields(GetRef(func), func->params, std::move(body), + GradRetType(GetRef(func)))); } return ad_mod; diff --git a/src/relay/transforms/higher_order_gradient.cc b/src/relay/transforms/higher_order_gradient.cc index 68a34c28c967..fa519d4fa547 100644 --- a/src/relay/transforms/higher_order_gradient.cc +++ b/src/relay/transforms/higher_order_gradient.cc @@ -346,7 +346,8 @@ struct ReverseAD : ExprMutator { params.push_back(Downcast(VisitExpr(p))); } params.push_back(bp); - Function f = WithFields(std::move(orig_f), std::move(params), VisitExpr(orig_f->body), VisitType(orig_f->ret_type)); + Function f = WithFields(std::move(orig_f), std::move(params), VisitExpr(orig_f->body), + VisitType(orig_f->ret_type)); std::cout << "gv " << op->name_hint << ": " << AsText(f, false) << std::endl; mod.value()->Add(gv, f); } @@ -360,8 +361,9 @@ struct ReverseAD : ExprMutator { } auto new_bp = Var("bp", bpt); params.push_back(new_bp); - return WithFields(GetRef(func_node), std::move(params), ReverseAD(mod, new_bp, ad_vars, ad_gvars)(func_node->body), - VisitType(func_node->ret_type)); + return WithFields(GetRef(func_node), std::move(params), + ReverseAD(mod, new_bp, ad_vars, ad_gvars)(func_node->body), + VisitType(func_node->ret_type)); } Type VisitType(const Type& t) final { return t.defined() ? ReverseType(t) : t; } @@ -455,7 +457,8 @@ Expr Gradient(const Expr& re, const Optional& mod) { }; return Pair(get_final_result(c, f->body->checked_type()), Tuple(ret)); }); - Function ret = WithFields(GetRef(f), f->params, std::move(body), GradRetType(GetRef(f))); + Function ret = + WithFields(GetRef(f), f->params, std::move(body), GradRetType(GetRef(f))); CheckFeature(ret, FeatureSet::All() - fGraph); return std::move(ret); } diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc index 9ddfb53dd3c9..b60eda47c73d 100644 --- a/src/relay/transforms/inline.cc +++ b/src/relay/transforms/inline.cc @@ -130,8 +130,8 @@ class Inliner : ExprMutator { const auto* fn = base_func.as(); ICHECK(fn) << "Expected to work on a Relay function."; - // TODO(@electriclilies): If Function is a COW node, then if it gets written to we shouldn't have any sharing, right? - // So we don't need to reconstruct? + // TODO(@electriclilies): If Function is a COW node, then if it gets written to we shouldn't + // have any sharing, right? So we don't need to reconstruct? Function func = WithFields(GetRef(fn)); // Inline the function body to the caller if this function uses default // compiler, i.e. no external codegen is needed. diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index d82597adb6d8..7c7804a1ed1c 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -521,7 +521,8 @@ class NameMangleExtFuncs : public MixedModeMutator { auto new_dict = func->attrs->dict; new_dict.Set(tvm::attr::kGlobalSymbol, String(relay::backend::SanitizeName(mangle_fn_(pair.first->name_hint)))); - func = WithFields(func, func->params, VisitExpr(func->body), func->ret_type, func->type_params, DictAttrs(new_dict)); + func = WithFields(func, func->params, VisitExpr(func->body), func->ret_type, + func->type_params, DictAttrs(new_dict)); new_module->Add(mangled_gvars_[pair.first->name_hint], func); } else { func = WithFields(func, func->params, VisitExpr(func->body)); diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 86b5b7a32126..7a08ef9cba4d 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -294,14 +294,15 @@ class Fill : ExprFunctor, private transform::Lexi // Keep track of expression and bound variable device types for lexically enclosing // sub-expressions. PushSEScope(f->virtual_device()); - for (const Var param: f->params) { + for (const Var param : f->params) { PushBoundVar(param, param->virtual_device()); } EnterFunctionBody(); - ret = WithFields(GetRef(f), f->params, GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body))); + ret = WithFields(GetRef(f), f->params, + GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body))); // We are done with this function. ExitFunctionBody(); - for (const Var param: f->params) { + for (const Var param : f->params) { PopBoundVar(param); } PopSEScope(); diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc index ad58123cadbe..bb5c02e6fad2 100644 --- a/src/relay/transforms/to_cps.cc +++ b/src/relay/transforms/to_cps.cc @@ -272,8 +272,9 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm, new_params.push_back(remap(v)); } new_params.push_back(k); - return WithFields(std::move(f), std::move(new_params), mut.VisitExpr(f->body, [&](const Expr& e) { return Call(k, {e}); }), - std::move(answer)); + return WithFields(std::move(f), std::move(new_params), + mut.VisitExpr(f->body, [&](const Expr& e) { return Call(k, {e}); }), + std::move(answer)); } Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) { @@ -299,7 +300,8 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) { Function ret = ToCPS(f, m, cm, &var, answer); auto new_type_params = ret->type_params; new_type_params.push_back(answer); - return WithFields(std::move(ret), ret->params, ret->body, ret->ret_type, std::move(new_type_params)); + return WithFields(std::move(ret), ret->params, ret->body, ret->ret_type, + std::move(new_type_params)); } Function ToCPS(const Function& f, const IRModule& m) { @@ -343,7 +345,8 @@ Function UnCPS(const Function& f) { type_args.push_back(new_ret_type); Call call = Call(f, args, {}, type_args); // How do I fix this? - return WithFields(f, std::move(new_params), call, std::move(new_ret_type), std::move(new_type_params)); + return WithFields(f, std::move(new_params), call, std::move(new_ret_type), + std::move(new_type_params)); } TVM_REGISTER_GLOBAL("relay._transform.to_cps") From cabbde51360358ceb8ce8a89ff0c70394bf05b4e Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Mon, 6 Dec 2021 16:01:21 -0800 Subject: [PATCH 06/13] lint & delete unneeded header --- src/relay/op/memory/on_device.h | 6 ------ src/relay/transforms/partial_eval.cc | 23 ++++++++++++----------- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/src/relay/op/memory/on_device.h b/src/relay/op/memory/on_device.h index 41edd5f1ba7c..af20882469ba 100644 --- a/src/relay/op/memory/on_device.h +++ b/src/relay/op/memory/on_device.h @@ -119,12 +119,6 @@ const NodeType* AsIgnoringOnDevice(const Expr& expr) { */ Function FunctionOnDevice(Function function, Array param_se_scopes, SEScope body_se_scope); -/*! - * \brief Returns the \p SEScope for the \p i'th parameter of \p function_node, or - * the unconstrained \p SEScope if function does not have the "param_se_scopes" annotation. - */ -SEScope GetFunctionParamSEScope(const FunctionNode* function_node, size_t i); - } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index bb2a6fc97b46..2e006cb8b3a7 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -827,17 +827,18 @@ class PartialEvaluator : public ExprFunctor Expr VisitFuncDynamic(const Function& func, const Func& f, const Expr& self) { return store_.Extend([&]() { store_.Invalidate(); - return WithFields(std::move(func), func->params, LetList::With([&](LetList* ll) { - std::vector pv; - for (const auto& v : func->params) { - pv.push_back(NoStatic(v)); - } - tvm::Array type_args; - for (const auto& tp : func->type_params) { - type_args.push_back(tp); - } - return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic; - })); + return WithFields( + std::move(func), func->params, LetList::With([&](LetList* ll) { + std::vector pv; + for (const auto& v : func->params) { + pv.push_back(NoStatic(v)); + } + tvm::Array type_args; + for (const auto& tp : func->type_params) { + type_args.push_back(tp); + } + return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic; + })); }); } From 7a0629d04cdf7c25cc73ad21d7df7e7b7365ef8a Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Mon, 6 Dec 2021 16:14:47 -0800 Subject: [PATCH 07/13] Remove FunctionOnDevice from everywhere except python tests --- src/relay/op/memory/on_device.h | 6 ------ src/relay/transforms/device_planner.cc | 11 +++++------ 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/src/relay/op/memory/on_device.h b/src/relay/op/memory/on_device.h index af20882469ba..c35d12c331e4 100644 --- a/src/relay/op/memory/on_device.h +++ b/src/relay/op/memory/on_device.h @@ -113,12 +113,6 @@ const NodeType* AsIgnoringOnDevice(const Expr& expr) { return props.body.as(); } -/*! - * \brief Returns \p function annotated with "param_se_scopes" and "result_se_scope" - * attributes capturing parameter and result \p SEScopes respectively. - */ -Function FunctionOnDevice(Function function, Array param_se_scopes, SEScope body_se_scope); - } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index ff73edc486f4..a75158381fc5 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -803,12 +803,12 @@ class DeviceCapturer : public ExprMutator { ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); SEScope result_se_scope = domains_->ResultSEScope(func_domain); ICHECK(!result_se_scope->IsFullyUnconstrained()); - Array param_se_scopes; - param_se_scopes.reserve(function_node->params.size()); + Array new_params; + new_params.reserve(function_node->params.size()); for (size_t i = 0; i < function_node->params.size(); ++i) { SEScope param_se_scope = domains_->ResultSEScope(func_domain->function_param(i)); ICHECK(!param_se_scope->IsFullyUnconstrained()); - param_se_scopes.push_back(param_se_scope); + new_params.push_back(WithFields(function_node->params[i], {}, {}, std::move(param_se_scope))); } // Rewrite the body. Note that the body may have begun with an "on_device" so @@ -818,9 +818,8 @@ class DeviceCapturer : public ExprMutator { /*expected_se_scope=*/result_se_scope, /*child_se_scope=*/GetSEScope(function_node->body), function_node->body); - Function func = WithFields(GetRef(function_node), std::move(function_node->params), - std::move(body)); - return FunctionOnDevice(func, std::move(param_se_scopes), std::move(result_se_scope)); + return WithFields(GetRef(function_node), std::move(new_params), + std::move(body), {}, {}, {}, std::move(result_se_scope)); } Expr VisitExpr_(const CallNode* call_node) final { From 47c1fd2ea88c7b23b8781d68214d52129a9f5d03 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Mon, 6 Dec 2021 16:18:04 -0800 Subject: [PATCH 08/13] Remove comment --- src/relay/transforms/to_cps.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc index bb5c02e6fad2..17d73b69d81e 100644 --- a/src/relay/transforms/to_cps.cc +++ b/src/relay/transforms/to_cps.cc @@ -343,9 +343,7 @@ Function UnCPS(const Function& f) { type_args.push_back(tp); } type_args.push_back(new_ret_type); - Call call = Call(f, args, {}, type_args); - // How do I fix this? - return WithFields(f, std::move(new_params), call, std::move(new_ret_type), + return WithFields(f, std::move(new_params), Call(f, args, {}, type_args), std::move(new_ret_type), std::move(new_type_params)); } From 6e25832cf4b23af209b1e7cb006e1afb7f37f718 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Mon, 6 Dec 2021 16:19:06 -0800 Subject: [PATCH 09/13] lint --- src/relay/transforms/device_planner.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index a75158381fc5..e7ff8e59981d 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -818,8 +818,8 @@ class DeviceCapturer : public ExprMutator { /*expected_se_scope=*/result_se_scope, /*child_se_scope=*/GetSEScope(function_node->body), function_node->body); - return WithFields(GetRef(function_node), std::move(new_params), - std::move(body), {}, {}, {}, std::move(result_se_scope)); + return WithFields(GetRef(function_node), std::move(new_params), std::move(body), {}, + {}, {}, std::move(result_se_scope)); } Expr VisitExpr_(const CallNode* call_node) final { From 3ec9c1514c9e69a95d7e2499aea840bace90068a Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Mon, 6 Dec 2021 16:26:47 -0800 Subject: [PATCH 10/13] remove funciton on device completely --- src/relay/op/memory/on_device.cc | 15 --------------- .../python/relay/op/annotation/test_annotation.py | 12 ------------ 2 files changed, 27 deletions(-) diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc index 5f21a1d868b6..6b9c340070e3 100644 --- a/src/relay/op/memory/on_device.cc +++ b/src/relay/op/memory/on_device.cc @@ -125,20 +125,5 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) { return {}; } -Function FunctionOnDevice(Function function, Array param_se_scopes, - SEScope result_se_scope) { - ICHECK(function->params.size() == param_se_scopes.size()) - << "ParamSEScopes must be the same size as the function parameters."; - Array new_params; - for (size_t i = 0; i < function->params.size(); i++) { - Var param = function->params[i]; - new_params.push_back(WithFields(std::move(param), {}, {}, std::move(param_se_scopes[i]))); - } - return WithFields(std::move(function), std::move(new_params), {}, {}, {}, {}, - std::move(result_se_scope)); -} - -TVM_REGISTER_GLOBAL("relay.op.annotation._make.FunctionOnDevice").set_body_typed(FunctionOnDevice); - } // namespace relay } // namespace tvm diff --git a/tests/python/relay/op/annotation/test_annotation.py b/tests/python/relay/op/annotation/test_annotation.py index 8ba91976523a..a51312ed4ea6 100644 --- a/tests/python/relay/op/annotation/test_annotation.py +++ b/tests/python/relay/op/annotation/test_annotation.py @@ -51,18 +51,6 @@ def test_on_device_is_fixed(): assert call.attrs.is_fixed -def test_function_on_device(): - x = relay.Var("x") - y = relay.Var("y") - f = relay.Function([x, y], relay.add(x, y)) - func = relay.annotation.function_on_device(f, ["cpu", "cuda"], "cuda") - assert isinstance(func, relay.Function) - assert len(func.attrs["param_se_scopes"]) == 2 - assert func.attrs["param_se_scopes"][0].device_type_int == 1 # ie kDLCPU - assert func.attrs["param_se_scopes"][1].device_type_int == 2 # ie kDLCUDA - assert func.attrs["result_se_scope"].device_type_int == 2 # ie KDLCUDA - - if __name__ == "__main__": import sys From 98ccd33a8e922960976ba2035922f6585cec1893 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Wed, 8 Dec 2021 13:23:18 -0800 Subject: [PATCH 11/13] progress removing function on device --- python/tvm/relay/op/annotation/annotation.py | 24 ------ src/relay/backend/te_compiler.cc | 3 + src/relay/op/memory/on_device.cc | 4 + src/relay/transforms/device_aware_visitors.cc | 7 ++ src/relay/transforms/virtual_device_check.cc | 78 +++++++++++++++++++ src/relay/transforms/virtual_device_check.h | 72 +++++++++++++++++ tests/python/relay/test_pass_fold_constant.py | 5 -- 7 files changed, 164 insertions(+), 29 deletions(-) create mode 100644 src/relay/transforms/virtual_device_check.cc create mode 100644 src/relay/transforms/virtual_device_check.h diff --git a/python/tvm/relay/op/annotation/annotation.py b/python/tvm/relay/op/annotation/annotation.py index cf70dc6e267e..b794b4d18e78 100644 --- a/python/tvm/relay/op/annotation/annotation.py +++ b/python/tvm/relay/op/annotation/annotation.py @@ -56,30 +56,6 @@ def on_device(data, device, is_fixed=False): return _make.OnDevice(data, _make_se_scope(device), is_fixed) -def function_on_device(function, param_devices, result_device): - """Annotates a Relay function with the device types on which its parameters and result should - be stored. - - Parameters - ---------- - function : tvm.relay.Function - The function to be annotated. - - param_devices : Array[Union[:py:class:`Device`, str]] - The devices for each parameter. - - result_device: Union[:py:class:`Device`, str] - The device for the function result. - - Returns - ------- - result : tvm.relay.Function - The annotated function. - """ - return _make.FunctionOnDevice( - function, [_make_se_scope(d) for d in param_devices], _make_se_scope(result_device) - ) - def stop_fusion(data): """Annotate an expression to prevent it being fused with following expressions. diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 273bbae8c477..93591d80dcee 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -746,6 +746,9 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { target = Target("ext_dev"); } else { // The target corresponding to the call_node expression's annotation. + // TODO(@electriclilies): The call node here should always be wrapped in OnDevice I think? Not sure why it's not getting put in. + + // SEScope se_scope = call_node->op->virtual_device(); SEScope se_scope = GetSEScope(GetRef(call_node)); ICHECK(!se_scope->IsFullyUnconstrained()); target = se_scope->target; diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc index 6b9c340070e3..3feccebad4a2 100644 --- a/src/relay/op/memory/on_device.cc +++ b/src/relay/op/memory/on_device.cc @@ -58,6 +58,7 @@ Expr OnDevice(Expr expr, SEScope se_scope, bool is_fixed) { TVM_REGISTER_GLOBAL("relay.op.annotation._make.OnDevice").set_body_typed(OnDevice); Expr MaybeOnDevice(Expr expr, SEScope se_scope, bool is_fixed) { + VLOG(1) << "MaybeOnDevice"; if (se_scope->IsFullyUnconstrained()) { // Nothing to annotate with. return expr; @@ -101,7 +102,10 @@ RELAY_REGISTER_OP("on_device") .set_attr("TNonComputational", true); OnDeviceProps GetOnDeviceProps(const CallNode* call_node) { + // This needs to extract the correct se scope from the function. + std::cout << "GetOnDeviceProps for: " << call_node; if (call_node->op == OnDeviceOp()) { + std::cout << "Is OnDevice op" << std::endl; ICHECK_EQ(call_node->args.size(), 1) << "on_device expects one argument"; ICHECK(call_node->attrs.defined()) << "on_device requires attributes"; const auto* on_device_attrs = call_node->attrs.as(); diff --git a/src/relay/transforms/device_aware_visitors.cc b/src/relay/transforms/device_aware_visitors.cc index 139fa03d2187..c4e6c4d3064b 100644 --- a/src/relay/transforms/device_aware_visitors.cc +++ b/src/relay/transforms/device_aware_visitors.cc @@ -48,10 +48,14 @@ LexicalOnDeviceMixin::LexicalOnDeviceMixin(const Optional& maybe_mod) } SEScope LexicalOnDeviceMixin::GetSEScope(const Expr& expr) const { + VLOG(1) << "GetSEScope for " << expr; OnDeviceProps props = GetOnDeviceProps(expr); + VLOG(1) << "props.body.defined(): " << props.body.defined(); if (props.body.defined() && props.is_fixed) { + VLOG(1) << "OnDeviceProps SEScope"; return props.se_scope; } else if (const auto* var_node = expr.as()) { + VLOG(1) << "VarNode SEScope"; // Lookup variable binding. auto itr = var_se_scopes_.find(GetRef(var_node)); if (itr != var_se_scopes_.end()) { @@ -59,6 +63,7 @@ SEScope LexicalOnDeviceMixin::GetSEScope(const Expr& expr) const { } // else: fallthrough to unconstrained } else if (const auto* global_var_node = expr.as()) { + VLOG(1) << "GlobalVarNode SEScope"; // Lookup global variable. auto itr = global_var_se_scopes_.find(GetRef(global_var_node)); if (itr != global_var_se_scopes_.end()) { @@ -66,6 +71,7 @@ SEScope LexicalOnDeviceMixin::GetSEScope(const Expr& expr) const { } // else: fallthrough to unconstrained } else if (const auto* function_node = expr.as()) { + VLOG(1) << "FunctionNode SEScope"; if (function_node->HasNonzeroAttr(attr::kPrimitive)) { if (!expr_se_scopes_.empty()) { // Use the currently in-scope device type. @@ -82,6 +88,7 @@ SEScope LexicalOnDeviceMixin::GetSEScope(const Expr& expr) const { } // else: fallthrough to unconstrained } + VLOG(1) << "Falling back to FullyUnconstrained SEScope"; return SEScope::FullyUnconstrained(); } diff --git a/src/relay/transforms/virtual_device_check.cc b/src/relay/transforms/virtual_device_check.cc new file mode 100644 index 000000000000..a7c68755ed16 --- /dev/null +++ b/src/relay/transforms/virtual_device_check.cc @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "./virtual_device_check.h" +#include + +namespace tvm { +using tvm::relay::transform::CreateFunctionPass; +using tvm::transform::PassContext; + +// TODO(@jroesch, @junru): we need to deal with unique spans for global/var. +void DeviceChecker::VisitExpr_(const VarNode* op) { + ICHECK((op->virtual_device_.defined())) << "VarNode's virtual device should not be null"; +} +void DeviceChecker::VisitExpr_(const GlobalVarNode* op) {} +void DeviceChecker::VisitExpr_(const ConstantNode* op) {} + +void DeviceChecker::VisitExpr_(const TupleNode* op) { ExprVisitor::VisitExpr_(op); } + +void DeviceChecker::VisitExpr_(const FunctionNode* op) { + ICHECK(op->virtual_device_.defined()); + ExprVisitor::VisitExpr_(op); } + +void DeviceChecker::VisitExpr_(const CallNode* op) { ExprVisitor::VisitExpr_(op); } + +void DeviceChecker::VisitExpr_(const LetNode* op) { ExprVisitor::VisitExpr_(op); } + +void DeviceChecker::VisitExpr_(const IfNode* op) { ExprVisitor::VisitExpr_(op); } + +void DeviceChecker::VisitExpr_(const OpNode* op) {} + +void DeviceChecker::VisitExpr_(const TupleGetItemNode* op) { ExprVisitor::VisitExpr_(op); } + +void DeviceChecker::VisitExpr_(const RefCreateNode* op) { ExprVisitor::VisitExpr_(op); } + +void DeviceChecker::VisitExpr_(const RefReadNode* op) { ExprVisitor::VisitExpr_(op); } + +void DeviceChecker::VisitExpr_(const RefWriteNode* op) { ExprVisitor::VisitExpr_(op); } + +void DeviceChecker::VisitExpr_(const ConstructorNode* op) {} // ExprVisitor::VisitExpr_(op); } + +void DeviceChecker::VisitExpr_(const MatchNode* op) { ExprVisitor::VisitExpr_(op); } + + +void DeviceChecker::VisitType(const Type& t) {} +void DeviceChecker::VisitClause(const Clause& c) {} +void DeviceChecker::VisitPattern(const Pattern& c) {} + +Pass VirtualDeviceCheck() { + return CreateFunctionPass( + [](const Function& func, const IRModule& mod, const PassContext& ctx) { + ICHECK(ctx->diag_ctx) << "Diagnostic context must be set."; + DeviceChecker checker; + checker.VisitExpr(func); + return func; + }, + 0, "VirtualDeviceCheck", {}); +} + +TVM_REGISTER_GLOBAL("VirtualDeviceCheck").set_body_typed([]() { return VirtualDeviceCheck(); }); + +} // namespace tvm \ No newline at end of file diff --git a/src/relay/transforms/virtual_device_check.h b/src/relay/transforms/virtual_device_check.h new file mode 100644 index 000000000000..3539757dfed5 --- /dev/null +++ b/src/relay/transforms/virtual_device_check.h @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file span_check.h + * \brief Check that the Relay IR has correctly attached span information. + */ + +#ifndef TVM_VIRTUAL_DEVICE_CHECK_H_ +#define TVM_VIRTUAL_DEVICE_CHECK_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { + +using namespace tvm::relay; +using tvm::transform::Pass; + +struct DeviceChecker : ExprVisitor { + + void VisitExpr(const Expr& expr) override; + void VisitExpr_(const VarNode* op) override; + void VisitExpr_(const GlobalVarNode* op) override; + void VisitExpr_(const ConstantNode* op) override; + void VisitExpr_(const TupleNode* op) override; + void VisitExpr_(const FunctionNode* op) override; + void VisitExpr_(const CallNode* op) override; + void VisitExpr_(const LetNode* op) override; + void VisitExpr_(const IfNode* op) override; + void VisitExpr_(const OpNode* op) override; + void VisitExpr_(const TupleGetItemNode* op) override; + void VisitExpr_(const RefCreateNode* op) override; + void VisitExpr_(const RefReadNode* op) override; + void VisitExpr_(const RefWriteNode* op) override; + void VisitExpr_(const ConstructorNode* op) override; + void VisitExpr_(const MatchNode* op) override; + void VisitType(const Type& t) override; + void VisitClause(const Clause& c) override; + void VisitPattern(const Pattern& c) override; + void VisitSpan(const Span& span) override; +}; + +Pass VirtualDeviceCheck(); + +} // namespace tvm +#endif // TVM_VIRTUAL_DEVICE_CHECK_H_ diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py index 3a5f458d5970..b9e3038bbcb1 100644 --- a/tests/python/relay/test_pass_fold_constant.py +++ b/tests/python/relay/test_pass_fold_constant.py @@ -22,11 +22,6 @@ from tvm.relay.testing import run_infer_type, create_workload -def annot_func(f): - """Returns f with arg/result device attributes for the argument and result.""" - return relay.op.annotation.function_on_device(f, [tvm.cpu()], tvm.cpu()) - - def annot_expr(e): """Returns e wrapped with an on_device annotation.""" return relay.op.annotation.on_device(e, tvm.cpu(), is_fixed=True) From 74f23df208f578c73db9fe34bb82f966e93ce1d4 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Wed, 8 Dec 2021 13:32:24 -0800 Subject: [PATCH 12/13] fix comments --- include/tvm/ir/expr.h | 3 +- include/tvm/relay/expr.h | 59 ++++++++++++++++++++++------------------ 2 files changed, 33 insertions(+), 29 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 310d20812ade..d33606676944 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -181,8 +181,7 @@ class RelayExprNode : public BaseExprNode { * The SEScope's Target field describes how the body of the function should be compiled. * * \note Unfortunately, the type of virtual_device_ needs to be ObjectRef to avoid a circular - * import. We can forward-declare the SEScope type for the getter function, but not for the field - * itself. + * import. */ mutable ObjectRef virtual_device_; diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 0f45970c9ea1..03200d3a3dfb 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -246,11 +246,12 @@ class Var : public Expr { * \param opt_type_annotation The (optional) type_annotation for the copied var. If none, * ret_var->type_annotation = var->type_annotation. * \param opt_virtual_device The (optional) virtual_device for the copied tuple. If none, - * ret_tuple->virtual_device = tuple->virtual_device. \param opt_span The (optional) span for the - * copied var. If none, ret_var->span = var->span. \return If all properties are null or the same as - * the property in the input var (i.e., opt_vid is null or opt_vid.value() == var->vid, etc.), then - * we return var. Otherwise, we return a copy of call with the different fields overwritten. (i.e., - * if opt_vid.value() != var->vid, then ret_var->vid = opt_.value()). + * ret_tuple->virtual_device = tuple->virtual_device. + * \param opt_span The (optional) span for the copied var. If none, ret_var->span = var->span. + * \return If all properties are null or the same as the property in the input var (i.e., opt_vid is + * null or opt_vid.value() == var->vid, etc.), then we return var. Otherwise, we return a copy of + * call with the different fields overwritten. (i.e., if opt_vid.value() != var->vid, then + * ret_var->vid = opt_.value()). */ Var WithFields(Var var, Optional opt_vid = Optional(), Optional opt_type_annotation = Optional(), @@ -370,11 +371,12 @@ class Call : public Expr { * \param opt_type_args The (optional) type args for the copied call. If none, * ret_call->type_args = call->type_args. * \param opt_virtual_device The (optional) virtual_device for the copied call. If none, - * ret_call->virtual_device = call->virtual_device. \param opt_span The (optional) span for the - * copied call. If none, ret_call->span = call->span. \return If all properties are null or the same - * as the property in the input call (i.e., opt_op is null or opt_op.value() == call->op, etc.), - * then we return call. Otherwise, we return a copy of call with the different fields overwritten. - * (i.e., if opt_op.value() != call->op, then ret_call->op = opt_op.value()). + * ret_call->virtual_device = call->virtual_device. + * \param opt_span The (optional) span for the copied call. If none, ret_call->span = call->span. + * \return If all properties are null or the same as the property in the input call (i.e., opt_op is + * null or opt_op.value() == call->op, etc.), then we return call. Otherwise, we return a copy of + * call with the different fields overwritten. (i.e., if opt_op.value() != call->op, then + * ret_call->op = opt_op.value()). */ Call WithFields(Call call, Optional opt_op = Optional(), Optional> opt_args = Optional>(), @@ -466,11 +468,12 @@ class Let : public Expr { * \param opt_value The (optional) value for the copied let. If none, ret_let->args = let->args. * \param opt_body The (optional) body for the copied let. If none, ret_let->attrs = let->attrs. * \param opt_virtual_device The (optional) virtual_device for the copied let. If none, - * ret_let->virtual_device = let->virtual_device. \param opt_span The (optional) span for the copied - * let. If none, ret_let->span = let->span. \return If all properties are null or the same as the - * property in the input let (i.e., opt_var is null or opt_var.value() == let->var, etc.), then we - * return let. Otherwise, we return a copy of let with the different fields overwritten. (i.e., if - * opt_var.value() != let->var, then ret_let->var = opt_var.value()). + * ret_let->virtual_device = let->virtual_device. + * \param opt_span The (optional) span for the copied let. If none, ret_let->span = let->span. + * \return If all properties are null or the same as the property in the input let (i.e., opt_var is + * null or opt_var.value() == let->var, etc.), then we return let. Otherwise, we return a copy of + * let with the different fields overwritten. (i.e., if opt_var.value() != let->var, then + * ret_let->var = opt_var.value()). */ Let WithFields(Let let, Optional opt_var = Optional(), Optional opt_value = Optional(), @@ -740,12 +743,13 @@ class RefRead : public Expr { * ref_read->ref. * \param opt_virtual_device * The (optional) virtual_device for the copied ref_read. If none, ret_ref_read->virtual_device = - * ref_read->virtual_device. \param opt_span The (optional) span for the copied ref_read. If none, - * ret_ref_read->span = ref_read->span. \return If all properties are null or the same as the - * property in the input ref_read (i.e., opt_ref is null or opt_ref.value() == ref_read->ref, etc.), - * then we return ref_read. Otherwise, we return a copy of ref_read with the different fields - * overwritten. (i.e., if opt_ref.value() != ref_read->ref, then ret_ref_read->ref = - * opt_ref.value()). + * ref_read->virtual_device. + * \param opt_span The (optional) span for the copied ref_read. If none, ret_ref_read->span = + * ref_read->span. + * \return If all properties are null or the same as the property in the input + * ref_read (i.e., opt_ref is null or opt_ref.value() == ref_read->ref, etc.), then we return + * ref_read. Otherwise, we return a copy of ref_read with the different fields overwritten. (i.e., + * if opt_ref.value() != ref_read->ref, then ret_ref_read->ref = opt_ref.value()). */ RefRead WithFields(RefRead ref_read, Optional opt_ref = Optional(), Optional opt_virtual_device = Optional(), @@ -806,12 +810,13 @@ class RefWrite : public Expr { * ret_ref_write->value = ref_write->value. * \param opt_virtual_device * The (optional) virtual_device for the copied ref_write. If none, ret_ref_write->virtual_device = - * ref_write->virtual_device. \param opt_span The (optional) span for the copied ref_write. If none, - * ret_ref_write->span = ref_write->span. \return If all properties are null or the same as the - * property in the input ref_write (i.e., opt_ref is null or opt_ref.value() == ref_write->ref, - * etc.), then we return ref_write. Otherwise, we return a copy of ref_write with the different - * fields overwritten. (i.e., if ref_write.value() != ref_write->ref, then ret_ref_write->ref = - * opt_ref.value()). + * ref_write->virtual_device. + * \param opt_span The (optional) span for the copied ref_write. If none, ret_ref_write->span = + * ref_write->span. + * \return If all properties are null or the same as the property in the input ref_write (i.e., + * opt_ref is null or opt_ref.value() == ref_write->ref, etc.), then we return ref_write. Otherwise, + * we return a copy of ref_write with the different fields overwritten. (i.e., if ref_write.value() + * != ref_write->ref, then ret_ref_write->ref = opt_ref.value()). */ RefWrite WithFields(RefWrite ref_write, Optional opt_ref = Optional(), Optional opt_value = Optional(), From a387e35797c1c56bbbeaf22a8217995682360c85 Mon Sep 17 00:00:00 2001 From: Lily Orth-Smith Date: Wed, 8 Dec 2021 16:41:16 -0800 Subject: [PATCH 13/13] debugging --- src/relay/backend/graph_executor_codegen.cc | 7 ++++++- src/relay/backend/te_compiler.cc | 20 +++++++++++++++---- src/relay/op/memory/on_device.cc | 1 - src/relay/transforms/device_aware_visitors.cc | 4 +++- src/relay/transforms/device_planner.cc | 2 ++ src/relay/transforms/to_a_normal_form.cc | 1 + src/relay/transforms/virtual_device_check.cc | 17 +++++++++++++--- src/relay/transforms/virtual_device_check.h | 15 ++++---------- 8 files changed, 46 insertions(+), 21 deletions(-) diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 3d889cdf6561..2ff601c307eb 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -41,6 +41,7 @@ #include "../op/call/call.h" #include "../op/memory/device_copy.h" #include "../transforms/device_aware_visitors.h" +#include "../transforms/virtual_device_check.h" #include "./te_compiler.h" #include "./utils.h" @@ -201,6 +202,8 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorfunction_metadata_); }, config->host_se_scope)(mod); + relay::VirtualDeviceCheck()(lowered_mod); Optional main_func_info = lowered_mod->GetAttr("main_func_info"); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 93591d80dcee..5bd181172bf4 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -47,6 +47,7 @@ #include "../op/call/call.h" #include "../op/memory/device_copy.h" #include "../transforms/device_aware_visitors.h" +#include "../transforms/virtual_device_check.h" #include "./te_compiler_cache.h" #include "./utils.h" @@ -100,6 +101,7 @@ class TECompilerImpl : public TECompilerNode { } IRModule GetLoweredFunctions() { + VLOG(1) << "GetLoweredFunctions"; IRModule mod; // Extract lowered functions from the cache for (const auto& it : cache_) { @@ -631,7 +633,7 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { } call_lowered_attrs->metadata.Set("all_prim_shape_fn_vars", all_prim_shape_fn_vars); } - + // What if the "function" is a prim_fn return CallLowered(cfunc->prim_fn_var, std::move(visited_args), Attrs(call_lowered_attrs), type_args, std::move(span)); } @@ -1063,9 +1065,12 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr // GlobalVar, and calls updated (sticking with regular Relay Call). // - Calls to functions tagged with "Primitive" are compiled to PrimFuncs, and calls updated // (using call_lowered convention). + relay::VirtualDeviceCheck()(module); + IRModule updated_module = LowerTensorExpr(module_name, compiler, std::move(process_fn), std::move(host_se_scope))(module); - + VLOG(1) << "Updated module: " << updated_module; + relay::VirtualDeviceCheck()(updated_module); // The Functions tagged with "Compiler" are now residing in the cache ready to be // compiled by LowerExternalFunctions. However we still need a record of them in the // IRModule so that the various executors can see which function names need to be @@ -1074,6 +1079,9 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr // Add the lowered functions. IRModule lowered_module = compiler->GetLoweredFunctions(); + VLOG(1) << "Got lowered functions"; + relay::VirtualDeviceCheck()(lowered_module); + VLOG(1) << "capturing " << lowered_module->functions.size() << " new lowered functions"; for (const auto& kv : lowered_module->functions) { if (updated_module->ContainGlobalVar(kv.first->name_hint)) { @@ -1085,6 +1093,8 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr } updated_module->Add(kv.first, kv.second); } + relay::VirtualDeviceCheck()(updated_module); + // Invoke external codegen for all Functions in the cache tagged with "Compiler", and // annotate the module with the resulting runtime modules. @@ -1128,7 +1138,7 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr } updated_module = WithAttr(updated_module, "op_weights", std::move(op_weights)); } - + relay::VirtualDeviceCheck()(lowered_module); return updated_module; } @@ -1171,7 +1181,9 @@ Pass LowerTEPass(const String& module_name, ProcessFn process_fn, SEScope host_s }; return tvm::transform::Sequential( - {tvm::relay::transform::RelayToTIRTargetHook(), + {relay::VirtualDeviceCheck(), + tvm::relay::transform::RelayToTIRTargetHook(), + relay::VirtualDeviceCheck(), tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {"InferType"}), InferType()}); } } // namespace tec diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc index 3feccebad4a2..43829f38ef1a 100644 --- a/src/relay/op/memory/on_device.cc +++ b/src/relay/op/memory/on_device.cc @@ -103,7 +103,6 @@ RELAY_REGISTER_OP("on_device") OnDeviceProps GetOnDeviceProps(const CallNode* call_node) { // This needs to extract the correct se scope from the function. - std::cout << "GetOnDeviceProps for: " << call_node; if (call_node->op == OnDeviceOp()) { std::cout << "Is OnDevice op" << std::endl; ICHECK_EQ(call_node->args.size(), 1) << "on_device expects one argument"; diff --git a/src/relay/transforms/device_aware_visitors.cc b/src/relay/transforms/device_aware_visitors.cc index c4e6c4d3064b..285c479dd95f 100644 --- a/src/relay/transforms/device_aware_visitors.cc +++ b/src/relay/transforms/device_aware_visitors.cc @@ -48,7 +48,7 @@ LexicalOnDeviceMixin::LexicalOnDeviceMixin(const Optional& maybe_mod) } SEScope LexicalOnDeviceMixin::GetSEScope(const Expr& expr) const { - VLOG(1) << "GetSEScope for " << expr; + VLOG(1) << "GetSEScope for " << PrettyPrint(expr); OnDeviceProps props = GetOnDeviceProps(expr); VLOG(1) << "props.body.defined(): " << props.body.defined(); if (props.body.defined() && props.is_fixed) { @@ -82,6 +82,7 @@ SEScope LexicalOnDeviceMixin::GetSEScope(const Expr& expr) const { return function_node->virtual_device(); } } else { + VLOG(1) << "Checking expr_se_scopes_"; if (!expr_se_scopes_.empty()) { // Use the currently in-scope device type. return expr_se_scopes_.back(); @@ -101,6 +102,7 @@ void LexicalOnDeviceMixin::ExitFunctionBody() { void LexicalOnDeviceMixin::PushSEScope(const SEScope& se_scope) { if (se_scope->IsFullyUnconstrained()) { + // ICHECK(false) << "Should be fully constrained"; return; } expr_se_scopes_.emplace_back(se_scope); diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index e7ff8e59981d..6bc01774e46c 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -259,6 +259,7 @@ #include "../op/memory/device_copy.h" #include "../op/memory/on_device.h" #include "./device_domains.h" +#include "./virtual_device_check.h" namespace tvm { namespace relay { @@ -1086,6 +1087,7 @@ tvm::transform::Pass PlanDevices(CompilationConfig config) { std::vector passes; passes.emplace_back(Rewrite()); passes.emplace_back(PlanDevicesCore(std::move(config))); + passes.emplace_back(VirtualDeviceCheck()); return tvm::transform::Sequential(passes, "PlanDevices"); } diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 7a08ef9cba4d..79d8082b108d 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -293,6 +293,7 @@ class Fill : ExprFunctor, private transform::Lexi } else { // Keep track of expression and bound variable device types for lexically enclosing // sub-expressions. + ICHECK(f->virtual_device_.defined()) << "Virtual device should be defined for function " << PrettyPrint(GetRef(f)); PushSEScope(f->virtual_device()); for (const Var param : f->params) { PushBoundVar(param, param->virtual_device()); diff --git a/src/relay/transforms/virtual_device_check.cc b/src/relay/transforms/virtual_device_check.cc index a7c68755ed16..b50fd1c19c50 100644 --- a/src/relay/transforms/virtual_device_check.cc +++ b/src/relay/transforms/virtual_device_check.cc @@ -18,15 +18,22 @@ */ #include "./virtual_device_check.h" + #include namespace tvm { +namespace relay { + using tvm::relay::transform::CreateFunctionPass; using tvm::transform::PassContext; +void DeviceChecker::VisitExpr(const Expr& e) { + ExprVisitor::VisitExpr(e); +} + // TODO(@jroesch, @junru): we need to deal with unique spans for global/var. void DeviceChecker::VisitExpr_(const VarNode* op) { - ICHECK((op->virtual_device_.defined())) << "VarNode's virtual device should not be null"; + // ICHECK((op->virtual_device_.defined())) << "VarNode's virtual device should not be null"; } void DeviceChecker::VisitExpr_(const GlobalVarNode* op) {} void DeviceChecker::VisitExpr_(const ConstantNode* op) {} @@ -34,7 +41,10 @@ void DeviceChecker::VisitExpr_(const ConstantNode* op) {} void DeviceChecker::VisitExpr_(const TupleNode* op) { ExprVisitor::VisitExpr_(op); } void DeviceChecker::VisitExpr_(const FunctionNode* op) { - ICHECK(op->virtual_device_.defined()); + ICHECK(!op->virtual_device()->IsFullyConstrained()); + for (auto var : op->params) { + ICHECK(!op->virtual_device()->IsFullyConstrained()); + } ExprVisitor::VisitExpr_(op); } void DeviceChecker::VisitExpr_(const CallNode* op) { ExprVisitor::VisitExpr_(op); } @@ -62,7 +72,7 @@ void DeviceChecker::VisitType(const Type& t) {} void DeviceChecker::VisitClause(const Clause& c) {} void DeviceChecker::VisitPattern(const Pattern& c) {} -Pass VirtualDeviceCheck() { +tvm::transform::Pass VirtualDeviceCheck() { return CreateFunctionPass( [](const Function& func, const IRModule& mod, const PassContext& ctx) { ICHECK(ctx->diag_ctx) << "Diagnostic context must be set."; @@ -75,4 +85,5 @@ Pass VirtualDeviceCheck() { TVM_REGISTER_GLOBAL("VirtualDeviceCheck").set_body_typed([]() { return VirtualDeviceCheck(); }); +} // namespace relay } // namespace tvm \ No newline at end of file diff --git a/src/relay/transforms/virtual_device_check.h b/src/relay/transforms/virtual_device_check.h index 3539757dfed5..7be051806cc5 100644 --- a/src/relay/transforms/virtual_device_check.h +++ b/src/relay/transforms/virtual_device_check.h @@ -18,7 +18,7 @@ */ /*! - * \file span_check.h + * \file virtual_device_check.h * \brief Check that the Relay IR has correctly attached span information. */ @@ -32,15 +32,8 @@ #include #include -#include -#include -#include -#include - namespace tvm { - -using namespace tvm::relay; -using tvm::transform::Pass; +namespace relay { struct DeviceChecker : ExprVisitor { @@ -63,10 +56,10 @@ struct DeviceChecker : ExprVisitor { void VisitType(const Type& t) override; void VisitClause(const Clause& c) override; void VisitPattern(const Pattern& c) override; - void VisitSpan(const Span& span) override; }; -Pass VirtualDeviceCheck(); +tvm::transform::Pass VirtualDeviceCheck(); +} // namespace relay } // namespace tvm #endif // TVM_VIRTUAL_DEVICE_CHECK_H_