diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index 31dec2204146..cdb8e52d2359 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -264,17 +264,9 @@ class Clause : public ObjectRef { }; /*! - * \brief Returns the clause with given properties. A null property denotes 'no change'. - * Returns clause if all properties are unchanged. Otherwise, returns a copy with the new fields. - * \param clause The clause to copy. - * \param opt_lhs The (optional) lhs for the copied clause. If none, ret_clause->lhs = clause->lhs. - * \param opt_rhs The (optional) rhs for the copied clause. If none, - * ret_clause->rhs = clause->rhs. - * \return If all - * properties are null or the same as the property in the input clause (i.e., opt_lhs is null or - * opt_lhs.value() == clause->lhs, etc.), then we return clause. Otherwise, we return a copy of - * clause with the different fields overwritten. (i.e., if opt_lhs.value() != clause->lhs, then - * ret_clause->lhs = opt_lhs.value()). + * \brief Returns \p clause with the given properties. A null property denotes 'no change'. + * Returns \p clause if all properties are unchanged. Otherwise, returns a copy with the new + * fields. */ Clause WithFields(Clause clause, Optional opt_lhs = Optional(), Optional opt_rhs = Optional()); @@ -337,20 +329,9 @@ class Match : public Expr { }; /*! - * \brief Returns the match with given properties. A null property denotes 'no change'. - * Returns match if all properties are unchanged. Otherwise, returns a copy with the new fields. - * \param match The match to copy. - * \param opt_data The (optional) data for the copied match. If none, ret_match->data = match->data. - * \param opt_clauses The (optional) clauses for the copied match. If none, ret_match->clauses = - * match->clauses. - * \param opt_complete The (optional) complete for the copied match. If none, ret_match->complete = - * match->complete. - * \param opt_span The (optional) span for the copied match. If none, ret_match->span = match->span. - * \return If all properties are null or the same as the - * property in the input match (i.e., opt_clauses is null or opt_clauses.value() == match->clauses, - * etc.), then we return match. Otherwise, we return a copy of match with the different fields - * overwritten. (i.e., if opt_clauses.value() != match->clauses, then ret_match->clauses = - * opt_clauses.value()). + * \brief Returns \p match with the given properties. A null property denotes 'no change'. + * Returns \p match if all properties are unchanged. Otherwise, returns a copy with the new + * fields. */ Match WithFields(Match match, Optional opt_data = Optional(), Optional> opt_clauses = Optional>(), diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index fe570806922f..6b014c8478d8 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -39,6 +39,16 @@ #include "./type.h" namespace tvm { + +/*! + * \brief Returns \p global_var with the given properties. A null property denotes 'no change'. + * Returns \p global_var if all properties are unchanged. Otherwise, returns a copy with the new + * fields. + */ +GlobalVar WithFields(GlobalVar global_var, Optional opt_name_hint = {}, + Optional opt_type = {}, Optional opt_virtual_device = {}, + Optional opt_span = {}); + namespace relay { using Expr = tvm::RelayExpr; @@ -97,8 +107,17 @@ class Constant : public Expr { TVM_DLL explicit Constant(runtime::NDArray data, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ConstantNode); }; +/*! + * \brief Returns \p constant with the given properties. A null property denotes 'no change'. + * Returns \p constant if all properties are unchanged. Otherwise, returns a copy with the new + * fields. + */ +Constant WithFields(Constant constant, Optional opt_data = {}, + Optional opt_virtual_device = {}, Optional opt_span = {}); + /*! \brief Tuple of multiple Exprs */ class Tuple; /*! \brief Tuple container */ @@ -149,15 +168,9 @@ class Tuple : public Expr { }; /*! - * \brief Returns the tuple with given properties. A null property denotes 'no change'. - * Returns this if all properties are unchanged. Otherwise, returns a copy with the new fields. - * \param tuple The tuple to copy - * \param opt_fields The (optional) fields for the copied tuple. If none, ret_tuple->fields = - * tuple->fields. - * \param opt_virtual_device The (optional) virtual_device for the copied tuple. If none, - * ret_tuple->virtual_device = tuple->virtual_device. - * \param opt_span The (optional) span for the copied tuple. If none, - * ret_tuple->span = tuple->span. + * \brief Returns \p tuple with the given properties. A null property denotes 'no change'. + * Returns \p tuple if all properties are unchanged. Otherwise, returns a copy with the new + * fields. */ Tuple WithFields(Tuple tuple, Optional> opt_fields = Optional>(), Optional opt_virtual_device = Optional(), @@ -251,19 +264,9 @@ class Var : public Expr { }; /*! - * \brief Returns the var with given properties. A null property denotes 'no change'. - * Returns var if all properties are unchanged. Otherwise, returns a copy with the new fields. - * \param var The var to copy. - * \param opt_vid The (optional) vid for the copied var. If none, ret_var->vid = var->vid. - * \param opt_type_annotation The (optional) type_annotation for the copied var. If none, - * ret_var->type_annotation = var->type_annotation. - * \param opt_virtual_device The (optional) virtual_device for the copied tuple. If none, - * ret_tuple->virtual_device = tuple->virtual_device. - * \param opt_span The (optional) span for the copied var. If none, ret_var->span = var->span. - * \return If all properties are null or the same as the property in the input var (i.e., opt_vid is - * null or opt_vid.value() == var->vid, etc.), then we return var. Otherwise, we return a copy of - * call with the different fields overwritten. (i.e., if opt_vid.value() != var->vid, then - * ret_var->vid = opt_.value()). + * \brief Returns \p vor with the given properties. A null property denotes 'no change'. + * Returns \p var if all properties are unchanged. Otherwise, returns a copy with the new + * fields. */ Var WithFields(Var var, Optional opt_vid = Optional(), Optional opt_type_annotation = Optional(), @@ -374,22 +377,9 @@ class Call : public Expr { }; /*! - * \brief Returns the call with given properties. A null property denotes 'no change'. - * Returns call if all properties are unchanged. Otherwise, returns a copy with the new fields. - * \param call The call to copy. - * \param opt_op The (optional) op for the copied call. If none, ret_call->op = call->op. - * \param opt_args The (optional) args for the copied call. If none, ret_call->args = call->args. - * \param opt_attrs The (optional) attrs for the copied call. If none, ret_call->attrs = - * call->attrs. - * \param opt_type_args The (optional) type args for the copied call. If none, - * ret_call->type_args = call->type_args. - * \param opt_virtual_device The (optional) virtual_device for the copied call. If none, - * ret_call->virtual_device = call->virtual_device. - * \param opt_span The (optional) span for the copied call. If none, ret_call->span = call->span. - * \return If all properties are null or the same as the property in the input call (i.e., opt_op is - * null or opt_op.value() == call->op, etc.), then we return call. Otherwise, we return a copy of - * call with the different fields overwritten. (i.e., if opt_op.value() != call->op, then - * ret_call->op = opt_op.value()). + * \brief Returns \p call with the given properties. A null property denotes 'no change'. + * Returns \p call if all properties are unchanged. Otherwise, returns a copy with the new + * fields. */ Call WithFields(Call call, Optional opt_op = Optional(), Optional> opt_args = Optional>(), @@ -475,19 +465,9 @@ class Let : public Expr { }; /*! - * \brief Returns the let with given properties. A null property denotes 'no change'. - * Returns let if all properties are unchanged. Otherwise, returns a copy with the new fields. - * \param let The let to copy. - * \param opt_var The (optional) var for the copied let. If none, ret_let->op = let->op. - * \param opt_value The (optional) value for the copied let. If none, ret_let->args = let->args. - * \param opt_body The (optional) body for the copied let. If none, ret_let->attrs = let->attrs. - * \param opt_virtual_device The (optional) virtual_device for the copied let. If none, - * ret_let->virtual_device = let->virtual_device. - * \param opt_span The (optional) span for the copied let. If none, ret_let->span = let->span. - * \return If all properties are null or the same as the property in the input let (i.e., opt_var is - * null or opt_var.value() == let->var, etc.), then we return let. Otherwise, we return a copy of - * let with the different fields overwritten. (i.e., if opt_var.value() != let->var, then - * ret_let->var = opt_var.value()). + * \brief Returns \p let with the given properties. A null property denotes 'no change'. + * Returns \p let if all properties are unchanged. Otherwise, returns a copy with the new + * fields. */ Let WithFields(Let let, Optional opt_var = Optional(), Optional opt_value = Optional(), @@ -559,23 +539,9 @@ class If : public Expr { }; /*! - * \brief Returns the if_expr with given properties. A null property denotes 'no change'. - * Returns if_expr if all properties are unchanged. Otherwise, returns a copy with the new fields. - * \param if_expr The if expression to copy. - * \param opt_cond The (optional) cond for the copied if_expr. If none, ret_if->cond = - * if_expr->cond. - * \param opt_true_branch The (optional) true_branch for the copied if_expr. If none, - * ret_if->true_branch = ret_if->false_branch. - * \param opt_false_branch The (optional) false_branch - * for the copied if_expr. If none, ret_if->false_branch = if_expr->false_branch. - * \param opt_virtual_device The (optional) virtual_device for the copied if_expr. If none, - * ret_if->virtual_device = if_expr->virtual_device. - * \param opt_span The (optional) span for the copied if_expr. If none, - * ret_if->span = if_expr->span. - * \return If all properties are null or the same as the property in - * the input if_expr (i.e., opt_cond is null or opt_cond.value() == if_expr->cond, etc.), then we - * return if_expr. Otherwise, we return a copy of if_expr with the different fields overwritten. - * (i.e., if opt_cond.value() != if_expr->cond, then ret_if->cond = opt_cond.value()). + * \brief Returns \p if_expr with the given properties. A null property denotes 'no change'. + * Returns \p if_expr if all properties are unchanged. Otherwise, returns a copy with the new + * fields. */ If WithFields(If if_expr, Optional opt_cond = Optional(), Optional opt_true_branch = Optional(), @@ -628,22 +594,9 @@ class TupleGetItem : public Expr { }; /*! - * \brief Returns the tuple_get_item with given properties. A null property denotes 'no change'. - * Returns if_expr if all properties are unchanged. Otherwise, returns a copy with the new fields. - * \param tuple_get_item The tuple_get_item to copy. - * \param opt_tuple The (optional) tuple for the copied tuple_get_item. If none, - * ret_tuple_get_item->tuple = tuple_get_item->tuple. - * \param opt_index The (optional) index for the copied tuple_get_item. If none, - * ret_tuple_get_item->index = tuple_get_item->index. - * \param opt_virtual_device The (optional) virtual_device for the copied tuple_get_item. - * If none, ret_tuple_get_item->virtual_device = tuple_get_item->virtual_device. - * \param opt_span The (optional) span for the copied tuple_get_item. If none, - * ret_tuple_get_item->span = tuple_get_item->span. - * \return If all properties are null or the same as the property in the input tuple_get_item - * (i.e., opt_tuple is null or opt_tuple.value() == tuple_get_item->tuple, etc.), then we return - * tuple_get_item. Otherwise, we return a copy of tuple_get_item with the different fields - * overwritten. (i.e., if opt_tuple.value() != tuple_get_item->tuple, then - * ret_tuple_get_item->tuple = opt_tuple.value()). + * \brief Returns \p tuple_get_item with the given properties. A null property denotes 'no change'. + * Returns \p tuple_get_item if all properties are unchanged. Otherwise, returns a copy with the new + * fields. */ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple = Optional(), Optional opt_index = Optional(), @@ -692,21 +645,9 @@ class RefCreate : public Expr { }; /*! - * \brief Returns the ref create with given properties. A null property denotes 'no change'. - * Returns ref_create if all properties are unchanged. Otherwise, returns a copy with the new + * \brief Returns \p ref_create with the given properties. A null property denotes 'no change'. + * Returns \p ref_crete if all properties are unchanged. Otherwise, returns a copy with the new * fields. - * \param ref_create The ref_create to copy. - * \param opt_value The (optional) value for the copied ref_create. If none, - * ret_ref_create->value = ref_create->value. - * \param opt_virtual_device The (optional) virtual_device for the copied ref_create. If none, - * ret_ref_create->virtual_device = ref_create->virtual_device. - * \param opt_span The (optional) span for the copied ref_create. If none, - * ret_ref_create->span = ref_create->span. - * \return If all properties are null or the same as the property in the input ref_create - * (i.e., opt_value is null or opt_value.value() == ref_create->value, etc.), then we return - * ref_create. Otherwise, we return a copy of ref_create with the different fields overwritten. - * (i.e., if opt_value.value() != ref_create->value, then - * ret_ref_create->value = opt_value.value()). */ RefCreate WithFields(RefCreate ref_create, Optional opt_value = Optional(), Optional opt_virtual_device = Optional(), @@ -754,20 +695,9 @@ class RefRead : public Expr { }; /*! - * \brief Returns the ref read with given properties. A null property denotes 'no change'. - * Returns ref_read if all properties are unchanged. Otherwise, returns a copy with the new fields. - * \param ref_read The ref_read to copy. - * \param opt_ref The (optional) ref for the copied ref_read. If none, ret_ref_read->ref = - * ref_read->ref. - * \param opt_virtual_device - * The (optional) virtual_device for the copied ref_read. If none, ret_ref_read->virtual_device = - * ref_read->virtual_device. - * \param opt_span The (optional) span for the copied ref_read. If none, ret_ref_read->span = - * ref_read->span. - * \return If all properties are null or the same as the property in the input - * ref_read (i.e., opt_ref is null or opt_ref.value() == ref_read->ref, etc.), then we return - * ref_read. Otherwise, we return a copy of ref_read with the different fields overwritten. (i.e., - * if opt_ref.value() != ref_read->ref, then ret_ref_read->ref = opt_ref.value()). + * \brief Returns \p ref_read with the given properties. A null property denotes 'no change'. + * Returns \p ref_read if all properties are unchanged. Otherwise, returns a copy with the new + * fields. */ RefRead WithFields(RefRead ref_read, Optional opt_ref = Optional(), Optional opt_virtual_device = Optional(), @@ -820,22 +750,9 @@ class RefWrite : public Expr { }; /*! - * \brief Returns the ref write with given properties. A null property denotes 'no change'. - * Returns ref_write if all properties are unchanged. Otherwise, returns a copy with the new fields. - * \param ref_write The ref_write to copy. - * \param opt_ref The (optional) ref for the copied ref_write. If none, - * ret_ref_write->ref = ref_write->ref. - * \param opt_value The (optional) value for the copied ref_write. If none, - * ret_ref_write->value = ref_write->value. - * \param opt_virtual_device - * The (optional) virtual_device for the copied ref_write. If none, ret_ref_write->virtual_device = - * ref_write->virtual_device. - * \param opt_span The (optional) span for the copied ref_write. If none, ret_ref_write->span = - * ref_write->span. - * \return If all properties are null or the same as the property in the input ref_write (i.e., - * opt_ref is null or opt_ref.value() == ref_write->ref, etc.), then we return ref_write. Otherwise, - * we return a copy of ref_write with the different fields overwritten. (i.e., if ref_write.value() - * != ref_write->ref, then ret_ref_write->ref = opt_ref.value()). + * \brief Returns \p ref_write with the given properties. A null property denotes 'no change'. + * Returns \p ref_write if all properties are unchanged. Otherwise, returns a copy with the new + * fields. */ RefWrite WithFields(RefWrite ref_write, Optional opt_ref = Optional(), Optional opt_value = Optional(), diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index d8f575dfdf48..280a1f8a6c29 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -240,6 +240,8 @@ class MixedModeVisitor : public ::tvm::relay::ExprVisitor { */ explicit MixedModeVisitor(int visit_limit = 1); + using ExprVisitor::VisitExpr_; + /*! * \brief VisitExpr is finalized to preserve call expansion of dataflow regions */ diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index 052d04fe2411..874d4f233416 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -121,28 +121,9 @@ class Function : public BaseFunc { }; /*! - * \brief Returns the function with given properties. A null property denotes 'no change'. - * Returns function if all properties are unchanged. Otherwise, returns a copy with the new fields. - * \param function The function to copy. - * \param opt_params The (optional) params for the copied function. If none, - * ret_function->params = function->params. - * \param opt_body The (optional) body for the copied function. If none, - * ret_function->body = function->body. - * \param opt_ret_type The (optional) return type for the copied function. If none, - * ret_function->ret_type = function->ret_type. - * \param opt_ty_params The (optional) type params for the copied function. If none, - * ret_function->type_params = function->type_params. - * \param opt_attrs - * The (optional) attributes for the copied function. If none, - * ret_function->attrs = function->attrs. - * \param opt_virtual_device The (optional) virtual_device for the copied function. If none, - * ret_function->virtual_device = function->virtual_device. - * \param opt_span The (optional) span for the copied function. If none, - * ret_function->span = function->span. - * \return If all properties are null or the same as the property in the input function - * (i.e., opt_params is null or opt_params.value() == function->params, etc.), then we return - * function. Otherwise, we return a copy of function with the different fields overwritten. (i.e., - * if opt_params.value() != function->params, then ret_function->params = opt_params.value()). + * \brief Returns \p function with the given properties. A null property denotes 'no change'. + * Returns \p function if all properties are unchanged. Otherwise, returns a copy with the new + * fields. */ Function WithFields(Function function, Optional> opt_params = Optional>(), Optional opt_body = Optional(), diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc index db36d02896a2..97c039ee29cf 100644 --- a/src/relay/backend/contrib/cutlass/codegen.cc +++ b/src/relay/backend/contrib/cutlass/codegen.cc @@ -602,7 +602,7 @@ class CodegenCutlass : public backend::MemoizedExprTranslatorbody.as(), 0, {"nn.dense"}); + const auto* dense_call = GetRootCall(callee->body.as(), 0, "nn.dense"); return GenerateBody(dense_call, "cutlass_dense", GetArgumentNames(caller), DenseArgs(std::ref(attrs_))); } else if (pattern_name == "cutlass.dense_bias") { @@ -637,11 +637,11 @@ class CodegenCutlass : public backend::MemoizedExprTranslatorbody.as(), 0, {"nn.batch_matmul"}); + GetRootCall(callee->body.as(), 0, "nn.batch_matmul"); return GenerateBody(batch_matmul_call, "cutlass_batch_matmul", GetArgumentNames(caller), BatchMatmulArgs(std::ref(attrs_))); } else if (pattern_name == "cutlass.conv2d") { - const auto* conv2d_call = GetRootCall(callee->body.as(), 0, {"nn.conv2d"}); + const auto* conv2d_call = GetRootCall(callee->body.as(), 0, "nn.conv2d"); return GenerateBody(conv2d_call, "cutlass_conv2d", GetArgumentNames(caller), Conv2dArgs(std::ref(attrs_))); } else if (pattern_name == "cutlass.conv2d_bias") { @@ -704,13 +704,12 @@ class CodegenCutlass : public backend::MemoizedExprTranslatorbody.as(), 0, {"nn.conv2d_transpose"}); + const auto* conv2d_call = GetRootCall(callee->body.as(), 0, "nn.conv2d_transpose"); return GenerateBody(conv2d_call, "cutlass_conv2d_transpose", GetArgumentNames(caller), Conv2dArgs(std::ref(attrs_), true, false)); } else if (pattern_name == "cutlass.conv2d_backward_weight") { const auto* conv2d_call = - GetRootCall(callee->body.as(), 0, {"nn.conv2d_backward_weight"}); + GetRootCall(callee->body.as(), 0, "nn.conv2d_backward_weight"); return GenerateBody(conv2d_call, "cutlass_conv2d_backward_weight", GetArgumentNames(caller), Conv2dArgs(std::ref(attrs_), false, true)); } diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index fc76577bd7c0..85892e8223af 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -27,6 +27,26 @@ namespace tvm { +GlobalVar WithFields(GlobalVar global_var, Optional opt_name_hint, Optional opt_type, + Optional opt_virtual_device, Optional opt_span) { + String name_hint = opt_name_hint.value_or(global_var->name_hint); + Type type = opt_type.value_or(global_var->checked_type()); + VirtualDevice virtual_device = opt_virtual_device.value_or(global_var->virtual_device()); + Span span = opt_span.value_or(global_var->span); + bool all_fields_unchanged = + name_hint.same_as(global_var->name_hint) && type.same_as(global_var->checked_type()) && + virtual_device.same_as(global_var->virtual_device()) && span.same_as(global_var->span); + if (!all_fields_unchanged) { + GlobalVarNode* cow_global_var_node = global_var.CopyOnWrite(); + cow_global_var_node->name_hint = name_hint; + cow_global_var_node->checked_type_ = type; + cow_global_var_node->virtual_device_ = virtual_device; + cow_global_var_node->span = span; + } + + return global_var; +} + VirtualDevice RelayExprNode::virtual_device() const { if (!this->virtual_device_.defined()) { // virtual_device_ should always be defined, unless we imported this node from JSON using an old @@ -77,6 +97,25 @@ TensorType ConstantNode::tensor_type() const { return TensorType(shape, dtype); } +Constant WithFields(Constant constant, Optional opt_data, + Optional opt_virtual_device, Optional opt_span) { + runtime::NDArray data = opt_data.value_or(constant->data); + VirtualDevice virtual_device = opt_virtual_device.value_or(constant->virtual_device()); + Span span = opt_span.value_or(constant->span); + + bool all_fields_unchanged = data.same_as(constant->data) && + virtual_device.same_as(constant->virtual_device()) && + span.same_as(constant->span); + + if (!all_fields_unchanged) { + ConstantNode* cow_constant_node = constant.CopyOnWrite(); + cow_constant_node->data = data; + cow_constant_node->virtual_device_ = virtual_device; + cow_constant_node->span = span; + } + return constant; +} + Tuple::Tuple(tvm::Array fields, Span span) { ObjectPtr n = make_object(); n->fields = std::move(fields); @@ -90,6 +129,7 @@ TVM_REGISTER_NODE_TYPE(TupleNode); TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array fields, Span span) { return Tuple(fields, span); }); + Tuple WithFields(Tuple tuple, Optional> opt_fields, Optional opt_virtual_device, Optional opt_span) { Array fields = opt_fields.value_or(tuple->fields); @@ -189,6 +229,7 @@ Call::Call(Expr op, Array args, Attrs attrs, Array type_args, Span s Call WithFields(Call call, Optional opt_op, Optional> opt_args, Optional opt_attrs, Optional> opt_type_args, Optional opt_virtual_device, Optional opt_span) { + // Collect new values for fields. Expr op = opt_op.value_or(call->op); Array args = opt_args.value_or(call->args); Attrs attrs = opt_attrs.value_or(call->attrs); @@ -196,37 +237,30 @@ Call WithFields(Call call, Optional opt_op, Optional> opt_args VirtualDevice virtual_device = opt_virtual_device.value_or(call->virtual_device()); Span span = opt_span.value_or(call->span); + // Check if anything changed. bool unchanged = op.same_as(call->op) && attrs.same_as(call->attrs) && virtual_device.same_as(call->virtual_device()) && span.same_as(call->span); - - // Check that the args are unchanged if (unchanged) { - bool all_args_unchanged = true; if (args.size() == call->args.size()) { for (size_t i = 0; i < args.size(); i++) { - all_args_unchanged &= args[i].same_as(call->args[i]); + unchanged &= args[i].same_as(call->args[i]); } } else { - all_args_unchanged = false; + unchanged = false; } - unchanged &= all_args_unchanged; } - - // Check that the type_args are unchanged if (unchanged) { - bool all_type_args_unchanged = true; if (type_args.size() == call->type_args.size()) { for (size_t i = 0; i < type_args.size(); i++) { - all_type_args_unchanged &= type_args[i].same_as(call->type_args[i]); + unchanged &= type_args[i].same_as(call->type_args[i]); } } else { - all_type_args_unchanged = false; + unchanged = false; } - - unchanged &= all_type_args_unchanged; } if (!unchanged) { + // If call is only references, update it in place. Otherwise copy and update. CallNode* cow_call_node = call.CopyOnWrite(); cow_call_node->op = op; cow_call_node->args = args; diff --git a/tests/cpp/relay/with_fields_test.cc b/tests/cpp/relay/with_fields_test.cc new file mode 100644 index 000000000000..48e04c259bb5 --- /dev/null +++ b/tests/cpp/relay/with_fields_test.cc @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and lixmitations + * under the License. + */ + +/*! + * \brief Proof-of-concept unit tests for the family of WithFields helpers. + * Only Call, GlobalVar and Constant are currently tested. + */ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace { + +IRModule TestIRModule() { + return parser::ParseModule("string", + R"( + #[version = "0.0.5"] + def @main(%data : Tensor[(1, 304, 128, 128), float32], + %weight1 : Tensor[(304, 1, 3, 3), float32], + %bias1 : Tensor[(304), float32], + %weight2 : Tensor[(256, 304, 1, 1), float32], + %bias2 : Tensor[(256), float32]) -> Tensor[(1, 256, 128, 128), float32] { + %0 = nn.conv2d(%data, %weight1, padding=[1, 1, 1, 1], groups=304, channels=304, kernel_size=[3, 3]); + %1 = nn.bias_add(%0, %bias1); + %2 = nn.relu(%1); + %3 = nn.conv2d(%2, %weight2, padding=[0, 0, 0, 0], channels=256, kernel_size=[1, 1]); + %4 = nn.bias_add(%3, %bias2); + nn.relu(%4) + } + )"); +} + +Function TestFunction() { return Downcast(TestIRModule()->Lookup("main")); } +Call TestCall() { return Downcast(TestFunction()->body); } +GlobalVar TestGlobalVar() { return TestIRModule()->GetGlobalVar("main"); } +VirtualDevice TestVirtualDevice() { return VirtualDevice::ForDevice({kDLCUDA, 3}); } +Span TestSpan() { return Span(SourceName::Get("foo"), 3, 4, 6, 42); } +Constant TestConstant() { + return Constant(runtime::NDArray::Empty({}, DataType::Int(32), {kDLCPU, 0})); +} + +// +// Call +// + +TEST(WithFields, Call_Noop) { + Call call = TestCall(); + Call result = WithFields(call); + ASSERT_TRUE(result.same_as(call)); +} + +TEST(WithFields, Call_Op) { + Call call = TestCall(); + Op new_op = Op::Get("tanh"); + Call result = WithFields(call, new_op); + ASSERT_FALSE(result.same_as(call)); + ASSERT_FALSE(call->op.same_as(new_op)); + ASSERT_TRUE(result->op.same_as(new_op)); +} + +TEST(WithFields, Call_Args) { + Call call = TestCall(); + Array new_args = {Tuple(Array())}; + Call result = WithFields(call, /*opt_op=*/{}, new_args); + ASSERT_FALSE(result.same_as(call)); + ASSERT_FALSE(call->args.same_as(new_args)); + ASSERT_TRUE(result->args.same_as(new_args)); +} + +TEST(WithFields, Call_Attrs) { + Call call = TestCall(); + Attrs new_attrs = DictAttrs(Map()); + Call result = WithFields(call, /*opt_op=*/{}, /*opt_args=*/{}, new_attrs); + ASSERT_FALSE(result.same_as(call)); + ASSERT_FALSE(call->attrs.same_as(new_attrs)); + ASSERT_TRUE(result->attrs.same_as(new_attrs)); +} + +TEST(WithFields, Call_TypeArgs) { + Call call = TestCall(); + Array new_type_args; + Call result = WithFields(call, /*opt_op=*/{}, /*opt_args=*/{}, /*opt_attrs=*/{}, new_type_args); + ASSERT_FALSE(result.same_as(call)); + ASSERT_FALSE(call->type_args.same_as(new_type_args)); + ASSERT_TRUE(result->type_args.same_as(new_type_args)); +} + +TEST(WithFields, Call_VirtualDevice) { + Call call = TestCall(); + VirtualDevice new_virtual_device = TestVirtualDevice(); + Call result = WithFields(call, /*opt_op=*/{}, /*opt_args=*/{}, /*opt_attrs=*/{}, + /*opt_type_args=*/{}, new_virtual_device); + ASSERT_FALSE(result.same_as(call)); + ASSERT_FALSE(call->virtual_device().same_as(new_virtual_device)); + ASSERT_TRUE(result->virtual_device().same_as(new_virtual_device)); +} + +TEST(WithFields, Call_Span) { + Call call = TestCall(); + Span new_span = TestSpan(); + Call result = WithFields(call, /*opt_op=*/{}, /*opt_args=*/{}, /*opt_attrs=*/{}, + /*opt_type_args=*/{}, /*opt_virtual_device=*/{}, new_span); + ASSERT_FALSE(result.same_as(call)); + ASSERT_FALSE(call->span.same_as(new_span)); + ASSERT_TRUE(result->span.same_as(new_span)); +} + +// +// GlobalVar +// + +TEST(WithFields, GlobalVar_Noop) { + GlobalVar gv = TestGlobalVar(); + GlobalVar result = WithFields(gv); + ASSERT_TRUE(result.same_as(gv)); +} + +TEST(WithFields, GlobalVar_Name) { + GlobalVar gv = TestGlobalVar(); + String new_name("foo"); + GlobalVar result = WithFields(gv, new_name); + ASSERT_FALSE(result.same_as(gv)); + ASSERT_FALSE(gv->name_hint.same_as(new_name)); + ASSERT_TRUE(result->name_hint.same_as(new_name)); +} + +TEST(WithFields, GlobalVar_Type) { + GlobalVar gv = TestGlobalVar(); + Type new_type = TupleType(Array()); + GlobalVar result = WithFields(gv, /*opt_name_hint=*/{}, new_type); + ASSERT_FALSE(result.same_as(gv)); + ASSERT_FALSE(gv->checked_type().same_as(new_type)); + ASSERT_TRUE(result->checked_type().same_as(new_type)); +} + +TEST(WithFields, GlobalVar_VirtualDevice) { + GlobalVar gv = TestGlobalVar(); + VirtualDevice new_virtual_device = TestVirtualDevice(); + GlobalVar result = WithFields(gv, /*opt_name_hint=*/{}, /*opt_type=*/{}, new_virtual_device); + ASSERT_FALSE(result.same_as(gv)); + ASSERT_FALSE(gv->virtual_device().same_as(new_virtual_device)); + ASSERT_TRUE(result->virtual_device().same_as(new_virtual_device)); +} + +TEST(WithFields, GlobalVar_Span) { + GlobalVar gv = TestGlobalVar(); + Span new_span = TestSpan(); + GlobalVar result = + WithFields(gv, /*opt_name_hint=*/{}, /*opt_type=*/{}, /*opt_virtual_device=*/{}, new_span); + ASSERT_FALSE(result.same_as(gv)); + ASSERT_FALSE(gv->span.same_as(new_span)); + ASSERT_TRUE(result->span.same_as(new_span)); +} + +// +// Constant +// + +TEST(WithFields, Constant_Noop) { + Constant constant = TestConstant(); + Constant result = WithFields(constant); + ASSERT_TRUE(result.same_as(constant)); +} + +TEST(WithFields, Constant_Data) { + Constant constant = TestConstant(); + runtime::NDArray new_data = runtime::NDArray::Empty({}, DataType::Float(32), {kDLCPU, 0}); + Constant result = WithFields(constant, new_data); + ASSERT_FALSE(result.same_as(constant)); + ASSERT_FALSE(constant->data.same_as(new_data)); + ASSERT_TRUE(result->data.same_as(new_data)); +} + +TEST(WithFields, Constant_VirtualDevice) { + Constant constant = TestConstant(); + VirtualDevice new_virtual_device = TestVirtualDevice(); + Constant result = WithFields(constant, /*opt_data=*/{}, new_virtual_device); + ASSERT_FALSE(result.same_as(constant)); + ASSERT_FALSE(constant->virtual_device().same_as(new_virtual_device)); + ASSERT_TRUE(result->virtual_device().same_as(new_virtual_device)); +} + +TEST(WithFields, Constant_Span) { + Constant constant = TestConstant(); + Span new_span = TestSpan(); + Constant result = WithFields(constant, /*opt_data=*/{}, /*opt_virtual_device=*/{}, new_span); + ASSERT_FALSE(result.same_as(constant)); + ASSERT_FALSE(constant->span.same_as(new_span)); + ASSERT_TRUE(result->span.same_as(new_span)); +} + +} // namespace +} // namespace relay +} // namespace tvm