diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index 9666475b8039..a6bcc28dad43 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -86,6 +86,7 @@ def _initialize_virtual_device(item, _): "relay.RefRead": _initialize_virtual_device, "relay.RefWrite": _initialize_virtual_device, "relay.Match": _initialize_virtual_device, + "relay.Constant": _initialize_virtual_device, } return create_updater(node_map, "0.8", "0.9") diff --git a/src/relay/backend/contrib/cmsisnn/extract_constants.cc b/src/relay/backend/contrib/cmsisnn/extract_constants.cc index 61f215a7d88c..9b724034ccf2 100644 --- a/src/relay/backend/contrib/cmsisnn/extract_constants.cc +++ b/src/relay/backend/contrib/cmsisnn/extract_constants.cc @@ -67,8 +67,8 @@ class ExtractConstantsMutator : public MixedModeMutator { auto new_body = VisitExpr(func->body); functions_.pop_back(); if (function_to_constants_[func].size()) { - func = Function(FreeVars(new_body), new_body, func->ret_type, FreeTypeVars(new_body, mod_), - func->attrs); + func = WithFields(func, FreeVars(new_body), new_body, func->ret_type, + FreeTypeVars(new_body, mod_), func->attrs); } return std::move(func); } @@ -159,8 +159,7 @@ IRModule ExtractConstants(const IRModule& mod) { auto new_main_body = extract_constants.VisitExpr(main_func->body); if (!new_main_body.same_as(main_func->body)) { auto main_var = mod->GetGlobalVar("main"); - auto new_main_func = Function(main_func->params, new_main_body, main_func->ret_type, - main_func->type_params, main_func->attrs); + Function new_main_func = WithFields(main_func, main_func->params, new_main_body); mod->Update(main_var, new_main_func); } return mod; diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index b8744247e9a6..f366e4ab2635 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -46,13 +46,8 @@ class RelayToTIRVisitor : public MixedModeMutator { IRModule Mutate() { GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); - BaseFunc main = ir_module_->Lookup(main_global_var); - Function main_func = GetRef(main.as()); - - // Copy everything across and mutate the body - Function mutated_main = - Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type, - main_func->type_params, main_func->attrs, main_func->span); + Function main = Downcast(ir_module_->Lookup(main_global_var)); + Function mutated_main = WithFields(main, main->params, VisitExpr(main->body)); ir_module_->Update(main_global_var, mutated_main); diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc index d618a4971189..0fdbb7063e3f 100644 --- a/src/relay/backend/contrib/ethosu/codegen.cc +++ b/src/relay/backend/contrib/ethosu/codegen.cc @@ -56,12 +56,8 @@ class RelayToTIRMutator : public MixedModeMutator { IRModule operator()() { GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); - Function main_func = Downcast(ir_module_->Lookup(main_global_var)); - - // Copy everything across and mutate the body - Function mutated_main = - Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type, - main_func->type_params, main_func->attrs, main_func->span); + Function main = Downcast(ir_module_->Lookup(main_global_var)); + Function mutated_main = WithFields(main, main->params, VisitExpr(main->body)); ir_module_->Update(main_global_var, mutated_main); ir_module_ = WithAttr(ir_module_, "device_contexts", device_contexts_); diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index 89b325f51a0c..6794594b5ba4 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -43,13 +43,8 @@ class ConvertAddToSubtract : public MixedModeMutator { IRModule Mutate() { GlobalVar main_global_var = ir_module_->GetGlobalVar("main"); - BaseFunc main = ir_module_->Lookup(main_global_var); - Function main_func = GetRef(main.as()); - - // Copy everything across and mutate the body - Function mutated_main = - Function(main_func->params, VisitExpr(main_func->body), main_func->ret_type, - main_func->type_params, main_func->attrs, main_func->span); + Function main = GetRef(ir_module_->Lookup(main_global_var).as()); + Function mutated_main = WithFields(main, main->params, VisitExpr(main->body)); ir_module_->Update(main_global_var, mutated_main); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 3ff6076473f1..3000ef9640f3 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -100,6 +100,7 @@ class TECompilerImpl : public TECompilerNode { } IRModule GetLoweredFunctions() { + VLOG(1) << "GetLoweredFunctions"; IRModule mod; // Extract lowered functions from the cache for (const auto& it : cache_) { @@ -164,8 +165,15 @@ class TECompilerImpl : public TECompilerNode { for (const auto& kv2 : kv1.second->cached_func->funcs->functions) { if (const auto* function_node = kv2.second.as()) { // Abandon the existing function annotations. - Function function(function_node->params, function_node->body, function_node->ret_type, - function_node->type_params, /*attrs=*/{}, function_node->span); + + // Unfortuantely, Optional() is indistinguishable from + // NullValue(), and DictAttrs() is nullptr, so to erase the attributes, we + // need pass in DictAttrs()), which is a DictAttrs containing no + // attributes. + Function function = + WithFields(GetRef(function_node), function_node->params, + function_node->body, function_node->ret_type, function_node->type_params, + /* erase attributes */ DictAttrs(Map())); // Mark function as 'extern' using the "ExternalSymbol" attribute. function = WithAttr(std::move(function), attr::kExternalSymbol, kv2.first->name_hint); module->Add(kv2.first, function); diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 0457459b3847..f2bd9e6b9a8a 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -102,8 +102,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { if (function_nesting() == 1) { // We don't need to lift global functions. - return Function(func_node->params, VisitExpr(func_node->body), func_node->ret_type, - func_node->type_params, func_node->attrs, func_node->span); + return WithFields(GetRef(func_node), func_node->params, VisitExpr(func_node->body)); } auto name = GenerateName(func); @@ -188,8 +187,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { // construct the "closure" function with fully annotated arguments, no longer relying // on type inference. size_t before_arity = body->params.size(); - auto rebound_body = Function(func->params, Bind(body->body, rebinding_map), func->ret_type, - func->type_params, func->attrs, func->span); + auto rebound_body = WithFields(func, func->params, Bind(body->body, rebinding_map)); size_t after_arity = rebound_body->params.size(); CHECK_EQ(before_arity, after_arity); lifted_func = diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 73ae3faf7078..fc76577bd7c0 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -45,6 +45,7 @@ using namespace tvm::runtime; Constant::Constant(runtime::NDArray data, Span span) { ObjectPtr n = make_object(); n->data = std::move(data); + n->virtual_device_ = VirtualDevice::FullyUnconstrained(); n->span = std::move(span); data_ = std::move(n); } diff --git a/src/relay/quantize/annotate.cc b/src/relay/quantize/annotate.cc index 3def616e9423..c704bcbc466b 100644 --- a/src/relay/quantize/annotate.cc +++ b/src/relay/quantize/annotate.cc @@ -98,7 +98,7 @@ Pass QuantizeAnnotate() { for (const auto& x : FreeVars(func)) { new_params.push_back(x); } - return Function(new_params, func->body, func->ret_type, func->type_params, func->attrs); + return WithFields(func, new_params); }; return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {}); } diff --git a/src/relay/quantize/calibrate.cc b/src/relay/quantize/calibrate.cc index 0ac445295496..21ed61187c38 100644 --- a/src/relay/quantize/calibrate.cc +++ b/src/relay/quantize/calibrate.cc @@ -152,8 +152,13 @@ class StatsCollector : private ExprMutator { const FunctionNode* func = new_e.as(); ICHECK(func) << "Input shoule be Function"; Expr new_body = Tuple(std::move(profile_data_)); - return Function(FreeVars(new_body), new_body, NullValue(), func->type_params, - func->attrs); + Function ret_func = WithFields(GetRef(func), FreeVars(new_body), new_body); + + // We are changing the function's ret_type to an empty type. Unfortunately, Optional() is + // indistinguishable from NullValue(), so we can't express "update to nullptr" in + // WithFields. + ret_func.CopyOnWrite()->ret_type = NullValue(); + return ret_func; } private: diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 6e4ab88ea326..3f1985b7ddfa 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -295,7 +295,7 @@ class AnnotateTargetRewriter : public ExprRewriter { func = Downcast(post); new_body = InsertCompilerEndAndPropogateTarget(func->body); } - return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs); + return WithFields(func, func->params, new_body); } Expr Rewrite_(const LetNode* op, const Expr& post) override { diff --git a/src/relay/transforms/convert_sparse_conv2d.cc b/src/relay/transforms/convert_sparse_conv2d.cc index 3f2c25e988f9..f2af290f3e22 100644 --- a/src/relay/transforms/convert_sparse_conv2d.cc +++ b/src/relay/transforms/convert_sparse_conv2d.cc @@ -292,12 +292,12 @@ Pass Conv2dToSparse(const Array& weight_name, const Array(Conv2dToSparse(f, weight_name, weight_shape, layout, kernel_size)); Array sparse_params = FreeVars(f0); - auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); + auto f1 = WithFields(f0, sparse_params); Array params = FreeVars(f1); for (const auto& var : sparse_params) { params.push_back(var); } - return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs); + return WithFields(f1, params); }; return CreateFunctionPass(pass_func, 4, "Conv2dToSparse", {"DeadCodeElimination"}); } diff --git a/src/relay/transforms/convert_sparse_dense.cc b/src/relay/transforms/convert_sparse_dense.cc index 26a4d487196d..faba366eca49 100644 --- a/src/relay/transforms/convert_sparse_dense.cc +++ b/src/relay/transforms/convert_sparse_dense.cc @@ -135,12 +135,12 @@ Pass DenseToSparse(const Array& weight_name, // Remove FreeVar warnings auto f0 = Downcast(DenseToSparse(f, weight_name, weight_shape)); Array sparse_params = FreeVars(f0); - auto f1 = Function(sparse_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); + auto f1 = WithFields(f0, sparse_params); Array params = FreeVars(f1); for (const auto& var : sparse_params) { params.push_back(var); } - return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs); + return WithFields(f1, params); }; return CreateFunctionPass(pass_func, 4, "DenseToSparse", {"DeadCodeElimination"}); } diff --git a/src/relay/transforms/de_duplicate.cc b/src/relay/transforms/de_duplicate.cc index b3e88376abcb..23e147d5d4c4 100644 --- a/src/relay/transforms/de_duplicate.cc +++ b/src/relay/transforms/de_duplicate.cc @@ -82,16 +82,17 @@ Expr DeDup(const Expr& e) { Type VisitType(const Type& t) final { return t.defined() ? TypeMutator::VisitType(t) : t; } - Expr VisitExpr_(const FunctionNode* op) final { + Expr VisitExpr_(const FunctionNode* func_node) final { tvm::Array type_params; - for (const TypeVar& type_param : op->type_params) { + for (const TypeVar& type_param : func_node->type_params) { type_params.push_back(Fresh(type_param)); } tvm::Array params; - for (const Var& param : op->params) { + for (const Var& param : func_node->params) { params.push_back(Fresh(param)); } - return Function(params, VisitExpr(op->body), VisitType(op->ret_type), type_params, op->attrs); + return WithFields(GetRef(func_node), params, VisitExpr(func_node->body), + VisitType(func_node->ret_type), type_params); } Pattern VisitPattern(const Pattern& p) final { return PatternFunctor::VisitPattern(p); } diff --git a/src/relay/transforms/defunctionalization.cc b/src/relay/transforms/defunctionalization.cc index 5255a672a856..38e403a8d9b0 100644 --- a/src/relay/transforms/defunctionalization.cc +++ b/src/relay/transforms/defunctionalization.cc @@ -283,7 +283,7 @@ class DefuncMutator : public ExprMutator { auto apply_gv = GetApplyFunction(ft); auto body = this->VisitExpr(Bind(fn->body, free_var_bind_map)); - AddApplyCase(apply_gv, ft, c, Function(fn->params, body, fn->ret_type, fn->type_params), + AddApplyCase(apply_gv, ft, c, WithFields(GetRef(fn), fn->params, body), pattern_vars); return Call(c, call_args); @@ -380,7 +380,7 @@ class DefuncMutator : public ExprMutator { map.Set(f->type_params[i], type_args[i]); } // copy with typevars removed - auto copy = TypeSubst(Function(f->params, f->body, f->ret_type, {}), map); + auto copy = TypeSubst(WithFields(f, {}, {}, {}, /* erase type params */ Array()), map); return Downcast(copy); } @@ -410,7 +410,8 @@ class DefuncMutator : public ExprMutator { } auto bind = Downcast(Bind(f, var_bind_map)); - return Function(params, this->VisitExpr(bind->body), bind->ret_type, {}); + return WithFields(bind, params, this->VisitExpr(bind->body), bind->ret_type, + /* erase type params */ Array()); } }; diff --git a/src/relay/transforms/eta_expand.cc b/src/relay/transforms/eta_expand.cc index 4023c9dafef4..40b0a54ba38c 100644 --- a/src/relay/transforms/eta_expand.cc +++ b/src/relay/transforms/eta_expand.cc @@ -129,8 +129,7 @@ class EtaExpander : public ExprMutator { params.push_back(var); args.push_back(var); } - - return Function(args, Call(gvar, params), func->ret_type, func->type_params); + return WithFields(func, args, Call(gvar, params)); } else { return std::move(gvar); } diff --git a/src/relay/transforms/first_order_gradient.cc b/src/relay/transforms/first_order_gradient.cc index d695c6dc491d..f530d61e0d99 100644 --- a/src/relay/transforms/first_order_gradient.cc +++ b/src/relay/transforms/first_order_gradient.cc @@ -307,8 +307,9 @@ Pass FirstOrderGradient() { }); return Pair(res.forward, grad_tuple); }); - ad_mod->Update(pr.first, - Function(func->params, body, GradRetType(GetRef(func)), {})); + ad_mod->Update(pr.first, WithFields(GetRef(func), func->params, body, + GradRetType(GetRef(func)), + /* erase type params */ Array())); } return ad_mod; diff --git a/src/relay/transforms/higher_order_gradient.cc b/src/relay/transforms/higher_order_gradient.cc index 202275626d5d..1cf7cb86692c 100644 --- a/src/relay/transforms/higher_order_gradient.cc +++ b/src/relay/transforms/higher_order_gradient.cc @@ -341,28 +341,28 @@ struct ReverseAD : ExprMutator { GlobalVar gv(op->name_hint + "_grad"); (*ad_gvars)[orig_gv] = gv; Function orig_f = Downcast(DeDup(mod.value()->Lookup(orig_gv))); - std::vector params; + Array params; for (const auto& p : orig_f->params) { params.push_back(Downcast(VisitExpr(p))); } params.push_back(bp); - Expr body = VisitExpr(orig_f->body); - Function f(params, body, VisitType(orig_f->ret_type), orig_f->type_params, orig_f->attrs); + Function f = WithFields(orig_f, params, VisitExpr(orig_f->body), VisitType(orig_f->ret_type)); std::cout << "gv " << op->name_hint << ": " << AsText(f, false) << std::endl; mod.value()->Add(gv, f); } return ad_gvars->at(orig_gv); } - Expr VisitExpr_(const FunctionNode* op) final { - std::vector params; - for (const auto& var : op->params) { + Expr VisitExpr_(const FunctionNode* func_node) final { + Array params; + for (const auto& var : func_node->params) { params.push_back(Downcast(VisitExpr(var))); } auto new_bp = Var("bp", bpt); params.push_back(new_bp); - return Function(params, ReverseAD(mod, new_bp, ad_vars, ad_gvars)(op->body), - VisitType(op->ret_type), op->type_params, op->attrs); + return WithFields(GetRef(func_node), params, + ReverseAD(mod, new_bp, ad_vars, ad_gvars)(func_node->body), + VisitType(func_node->ret_type)); } Type VisitType(const Type& t) final { return t.defined() ? ReverseType(t) : t; } @@ -456,7 +456,8 @@ Expr Gradient(const Expr& re, const Optional& mod) { }; return Pair(get_final_result(c, f->body->checked_type()), Tuple(ret)); }); - auto ret = Function(f->params, body, GradRetType(GetRef(f)), {}); + Function ret = WithFields(GetRef(f), f->params, body, GradRetType(GetRef(f)), + /* erase type params */ Array()); CheckFeature(ret, FeatureSet::All() - fGraph); return std::move(ret); } diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc index f1492b9f1258..a6e26364bbc4 100644 --- a/src/relay/transforms/inline.cc +++ b/src/relay/transforms/inline.cc @@ -91,8 +91,7 @@ class Inliner : ExprMutator { } Function Inline(const Function& func) { - return Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, - func->attrs); + return WithFields(func, func->params, VisitExpr(func->body)); } private: @@ -131,6 +130,8 @@ class Inliner : ExprMutator { const auto* fn = base_func.as(); ICHECK(fn) << "Expected to work on a Relay function."; + // There is an inconsistency here, the function itself gets shallow-copied but the body is not + // shallow-copied. auto func = Function(fn->params, fn->body, fn->ret_type, fn->type_params, fn->attrs); // Inline the function body to the caller if this function uses default // compiler, i.e. no external codegen is needed. diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index 28d1aa5532bf..fc9922ca03ef 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -827,18 +827,18 @@ class PartialEvaluator : public ExprFunctor Expr VisitFuncDynamic(const Function& func, const Func& f, const Expr& self) { return store_.Extend([&]() { store_.Invalidate(); - return Function(func->params, LetList::With([&](LetList* ll) { - std::vector pv; - for (const auto& v : func->params) { - pv.push_back(NoStatic(v)); - } - tvm::Array type_args; - for (const auto& tp : func->type_params) { - type_args.push_back(tp); - } - return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic; - }), - func->ret_type, func->type_params, func->attrs); + return WithFields( + func, func->params, LetList::With([&](LetList* ll) { + std::vector pv; + for (const auto& v : func->params) { + pv.push_back(NoStatic(v)); + } + tvm::Array type_args; + for (const auto& tp : func->type_params) { + type_args.push_back(tp); + } + return f(HasStatic(MkSFunc(f), self), pv, Attrs(), type_args, ll)->dynamic; + })); }); } diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index d1b9b563e932..bc1ed518d473 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -213,9 +213,8 @@ class Partitioner : public MixedModeMutator { auto glob_funcs = module_->functions; for (const auto& pair : glob_funcs) { if (auto* fn = pair.second.as()) { - auto func = GetRef(fn); - func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, - func->attrs); + Function func = GetRef(fn); + func = WithFields(func, func->params, VisitExpr(func->body)); module_->Update(pair.first, func); module_ = transform::InferType()(module_); } @@ -429,7 +428,7 @@ IRModule RemoveDefaultAnnotations(IRModule module) { auto func = GetRef(fn); DefaultRemover remover; auto removed = PostOrderRewrite(func->body, &remover); - func = Function(func->params, removed, func->ret_type, func->type_params, func->attrs); + func = WithFields(func, func->params, removed); module->Update(pair.first, func); module = relay::transform::InferType()(module); } @@ -482,10 +481,10 @@ IRModule FlattenTupleOutputs(IRModule module) { module.CopyOnWrite(); for (const auto& pair : glob_funcs) { if (auto* fn = pair.second.as()) { - auto func = GetRef(fn); + Function func = GetRef(fn); TupleOutFlattener to_flattener; auto removed = PostOrderRewrite(func->body, &to_flattener); - func = Function(func->params, removed, func->ret_type, func->type_params, func->attrs); + func = WithFields(func, func->params, removed); module->Update(pair.first, func); module = relay::transform::InferType()(module); } @@ -527,12 +526,12 @@ class NameMangleExtFuncs : public MixedModeMutator { auto new_dict = func->attrs->dict; new_dict.Set(tvm::attr::kGlobalSymbol, String(relay::backend::SanitizeName(mangle_fn_(pair.first->name_hint)))); - func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, - DictAttrs(new_dict)); + func = WithFields(func, func->params, VisitExpr(func->body), func->ret_type, + func->type_params, DictAttrs(new_dict)); + new_module->Add(mangled_gvars_[pair.first->name_hint], func); } else { - func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, - func->attrs); + func = WithFields(func, func->params, VisitExpr(func->body)); new_module->Add(pair.first, func); } } diff --git a/src/relay/transforms/pass_utils.h b/src/relay/transforms/pass_utils.h index 317ac17f83c8..b14a93f02b55 100644 --- a/src/relay/transforms/pass_utils.h +++ b/src/relay/transforms/pass_utils.h @@ -106,7 +106,7 @@ bool IsDataDependent(const CallNode* call); */ inline Expr TransformF(const std::function& func, const Expr& e) { if (const FunctionNode* f = e.as()) { - return Function(f->params, func(f->body), f->ret_type, f->type_params, f->attrs); + return WithFields(GetRef(f), f->params, func(f->body)); } else { return func(e); } diff --git a/src/relay/transforms/simplify_fc_transpose.cc b/src/relay/transforms/simplify_fc_transpose.cc index b5090e7e6fe4..ad38ea6cb8df 100644 --- a/src/relay/transforms/simplify_fc_transpose.cc +++ b/src/relay/transforms/simplify_fc_transpose.cc @@ -128,12 +128,12 @@ Pass SimplifyFCTranspose(const Array& target_weights) { // Remove FreeVar warning auto f0 = Downcast(SimplifyFCTranspose(f, target_weights)); Array wt_params = FreeVars(f0); - auto f1 = Function(wt_params, f0->body, f0->ret_type, f0->type_params, f0->attrs); + auto f1 = WithFields(f0, wt_params); Array params = FreeVars(f1); for (const auto& var : wt_params) { params.push_back(var); } - return Function(params, f1->body, f1->ret_type, f1->type_params, f1->attrs); + return WithFields(f1, params); }; return CreateFunctionPass(pass_func, 4, "SimplifyFCTranspose", {"DeadCodeElimination"}); } diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index f6d5ac9cf8bb..a0841ec44fae 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -298,8 +298,8 @@ class Fill : ExprFunctor, private transform::Lexi PushBoundVar(f->params[i], GetFunctionParamVirtualDevice(f, i)); } EnterFunctionBody(); - ret = Function(f->params, GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body)), f->ret_type, - f->type_params, f->attrs); + ret = WithFields(GetRef(f), f->params, + GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body))); // We are done with this function. ExitFunctionBody(); for (size_t i = 0; i < f->params.size(); ++i) { diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc index c5d17fbfbef7..6d8fe67847f6 100644 --- a/src/relay/transforms/to_cps.cc +++ b/src/relay/transforms/to_cps.cc @@ -272,8 +272,8 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm, new_params.push_back(remap(v)); } new_params.push_back(k); - return Function(new_params, mut.VisitExpr(f->body, [&](const Expr& e) { return Call(k, {e}); }), - answer, f->type_params, f->attrs); + return WithFields(f, new_params, + mut.VisitExpr(f->body, [&](const Expr& e) { return Call(k, {e}); }), answer); } Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) { @@ -299,7 +299,7 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) { Function ret = ToCPS(f, m, cm, &var, answer); auto new_type_params = ret->type_params; new_type_params.push_back(answer); - return Function(ret->params, ret->body, ret->ret_type, new_type_params, ret->attrs); + return WithFields(ret, ret->params, ret->body, ret->ret_type, new_type_params); } Function ToCPS(const Function& f, const IRModule& m) { @@ -311,7 +311,7 @@ Function ToCPS(const Function& f, const IRModule& m) { Function UnCPS(const Function& f) { CheckFeature(f, FeatureSet::All() - fGraph); ICHECK_GT(f->params.size(), 0); - std::vector new_params; + Array new_params; for (const auto& p : f->params) { new_params.push_back(Var(p->name_hint(), p->checked_type())); } @@ -319,7 +319,7 @@ Function UnCPS(const Function& f) { new_params.pop_back(); ICHECK_EQ(cont_type->arg_types.size(), 1); auto new_ret_type = Type(cont_type->arg_types[0]); - std::vector new_type_params; + Array new_type_params; for (const auto& tp : f->type_params) { new_type_params.push_back(TypeVar(tp->name_hint, tp->kind)); } @@ -339,8 +339,7 @@ Function UnCPS(const Function& f) { type_args.push_back(tp); } type_args.push_back(new_ret_type); - return Function(new_params, Call(f, args, {}, type_args), new_ret_type, new_type_params, - f->attrs); + return WithFields(f, new_params, Call(f, args, {}, type_args), new_ret_type, new_type_params); } TVM_REGISTER_GLOBAL("relay._transform.to_cps")