From 498143799e42762c55b209a32ca907a0a5044cf4 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 13 Aug 2021 13:46:38 -0700 Subject: [PATCH 1/9] Add DictAttrs to IRModuleNode Move GetAttrs to be a member of DictAttrs Generalize WithAttrs to work with IRModule and move to attrs.h Change func->GetAttr to func->attrs.GetAttr --- include/tvm/ir/attrs.h | 108 ++++++++++++++++++ include/tvm/ir/function.h | 103 ----------------- include/tvm/ir/module.h | 2 + include/tvm/target/target.h | 3 +- src/driver/driver_api.cc | 4 +- src/ir/module.cc | 2 +- src/relay/analysis/context_analysis.cc | 4 +- src/relay/analysis/extract_fused_functions.cc | 2 +- src/relay/analysis/feature.cc | 2 +- src/relay/analysis/get_calibration_data.cc | 6 +- src/relay/backend/aot_executor_codegen.cc | 2 +- src/relay/backend/compile_engine.cc | 10 +- .../contrib/arm_compute_lib/codegen.cc | 2 +- src/relay/backend/contrib/bnns/codegen.cc | 4 +- .../contrib/codegen_json/codegen_json.h | 6 +- src/relay/backend/contrib/dnnl/codegen.cc | 2 +- src/relay/backend/contrib/ethosn/codegen.cc | 4 +- src/relay/backend/graph_executor_codegen.cc | 4 +- src/relay/backend/graph_plan_memory.cc | 2 +- src/relay/backend/interpreter.cc | 4 +- src/relay/backend/te_compiler.cc | 22 ++-- src/relay/backend/utils.h | 6 +- src/relay/backend/vm/compiler.cc | 8 +- src/relay/backend/vm/inline_primitives.cc | 6 +- src/relay/backend/vm/lambda_lift.cc | 8 +- src/relay/ir/transform.cc | 4 +- src/relay/transforms/annotate_target.cc | 4 +- src/relay/transforms/fold_constant.cc | 2 +- src/relay/transforms/fuse_ops.cc | 4 +- src/relay/transforms/inline.cc | 6 +- src/relay/transforms/memory_alloc.cc | 4 +- src/relay/transforms/partial_eval.cc | 2 +- src/relay/transforms/partition_graph.cc | 4 +- src/relay/transforms/to_a_normal_form.cc | 4 +- .../transforms/to_basic_block_normal_form.cc | 2 +- src/relay/transforms/to_cps.cc | 2 +- src/target/build_common.h | 6 +- src/target/llvm/codegen_cpu.cc | 2 +- src/target/llvm/codegen_hexagon.cc | 2 +- src/target/llvm/codegen_llvm.cc | 4 +- src/target/llvm/llvm_module.cc | 4 +- src/target/opt/build_cuda_on.cc | 2 +- src/target/source/codegen_aocl.cc | 2 +- src/target/source/codegen_c.cc | 4 +- src/target/source/codegen_c_host.cc | 4 +- src/target/source/codegen_metal.cc | 6 +- src/target/source/codegen_opencl.cc | 2 +- src/target/source/codegen_vhls.cc | 4 +- src/target/spirv/build_vulkan.cc | 4 +- src/target/spirv/codegen_spirv.cc | 2 +- src/target/stackvm/codegen_stackvm.cc | 4 +- src/tir/analysis/verify_memory.cc | 4 +- src/tir/transforms/ir_utils.cc | 2 +- src/tir/transforms/lower_custom_datatypes.cc | 2 +- src/tir/transforms/lower_intrin.cc | 3 +- src/tir/transforms/lower_thread_allreduce.cc | 2 +- src/tir/transforms/lower_warp_memory.cc | 2 +- src/tir/transforms/make_packed_api.cc | 6 +- src/tir/transforms/make_unpacked_api.cc | 6 +- src/tir/transforms/remap_thread_axis.cc | 2 +- src/tir/transforms/split_host_device.cc | 4 +- src/tir/transforms/storage_flatten.cc | 2 +- 62 files changed, 230 insertions(+), 221 deletions(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index da7bc12619bd..3aaef82e89b9 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -214,6 +214,7 @@ class DictAttrsNode : public BaseAttrsNode { void VisitNonDefaultAttrs(AttrVisitor* v) final; void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final; Array ListFieldInfo() const final; + // type info static constexpr const char* _type_key = "DictAttrs"; TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode); @@ -232,6 +233,72 @@ class DictAttrs : public Attrs { */ TVM_DLL explicit DictAttrs(Map dict); + // Utils for accessing attributes + // This needs to be on DictAttrs, not DictAttrsNode because we return the default + // value if DictAttrsNode is not defined. + /*! + * \brief Get a function attribute. + * + * \param attr_key The attribute key. + * \param default_value The default value if the key does not exist, defaults to nullptr. + * + * \return The result + * + * \tparam TOBjectRef the expected object type. + * \throw Error if the key exists but the value does not match TObjectRef + * + * \code + * + * void GetAttrExample(const BaseFunc& f) { + * auto value = f->attrs.GetAttr("AttrKey", 0); + * } + * + * \endcode + */ + template + Optional GetAttr( + const std::string& attr_key, + Optional default_value = Optional(nullptr)) const { + static_assert(std::is_base_of::value, + "Can only call GetAttr with ObjectRef types."); + if (!defined()) return default_value; + const DictAttrsNode* node = this->as(); + + auto it = node->dict.find(attr_key); + if (it != node->dict.end()) { + return Downcast>((*it).second); + } else { + return default_value; + } + } + // variant that uses TObjectRef to enable implicit conversion to default value. + template + Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, Optional(default_value)); + } + /*! + * \brief Check whether the function has an non-zero integer attr. + * + * This function can be used to check whether an optional + * attribute mark(e.g. inline) exists. + * + * \param attr_key The key to the attribute. + * \return The check result. + * + * \code + * + * void HasNonzeroAttrExample(const BaseFunc& f) { + * if (f->attrs.HasNonzeroAttr(attr::kInline)) { + * // inline the function. + * } + * } + * + * \endcode + */ + bool HasNonzeroAttr(const std::string& attr_key) const { + return GetAttr(attr_key, 0) != 0; + } + TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode); }; @@ -249,6 +316,47 @@ inline TAttrs AttrsWithDefaultValues() { return TAttrs(n); } +/*! + * \brief Copy the function or module, but overrides + * the attribute value key with the value. + * + * \param input The thing to annotate (BaseFunc or IRModule) + * \param attr_key The attribute key. + * \param attr_value The value attribute value. + * + * \tparam TFunc The corresponding function or module type. + * + * \returns The new function or module with updated attributes. + * + * \note This function performs copy on write optimization for func and module. + * If we move a uniquely referenced func or module into WithAttr, + * then no additional copy will be performed. + * + * This is also why we make it as a function instead of a member function + * and why we pass by value in the first argument. + * + * \code + * + * // Recommended way to trigger copy on write + * func = WithAttr(std::move(func), "key1", value1); + * func = WithAttr(std::move(func), "key2", value2); + * + * \endcode + */ +template +inline TFunc WithAttr(TFunc func, const std::string& attr_key, ObjectRef attr_value) { + using TNode = typename TFunc::ContainerType; + static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); + TNode* node = func.CopyOnWrite(); + if (node->attrs.defined()) { + node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value); + } else { + Map dict = {{attr_key, attr_value}}; + node->attrs = DictAttrs(dict); + } + return func; +} + // Namespace containing detail implementations namespace detail { using runtime::TVMArgValue; diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 09c074cb71bd..020b3de77ab3 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -79,67 +79,6 @@ class BaseFuncNode : public RelayExprNode { /*! \brief Additional attributes storing the meta-data */ DictAttrs attrs; - /*! - * \brief Get a function attribute. - * - * \param attr_key The attribute key. - * \param default_value The default value if the key does not exist, defaults to nullptr. - * - * \return The result - * - * \tparam TOBjectRef the expected object type. - * \throw Error if the key exists but the value does not match TObjectRef - * - * \code - * - * void GetAttrExample(const BaseFunc& f) { - * auto value = f->GetAttr("AttrKey", 0); - * } - * - * \endcode - */ - template - Optional GetAttr( - const std::string& attr_key, - Optional default_value = Optional(nullptr)) const { - static_assert(std::is_base_of::value, - "Can only call GetAttr with ObjectRef types."); - if (!attrs.defined()) return default_value; - auto it = attrs->dict.find(attr_key); - if (it != attrs->dict.end()) { - return Downcast>((*it).second); - } else { - return default_value; - } - } - // variant that uses TObjectRef to enable implicit conversion to default value. - template - Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { - return GetAttr(attr_key, Optional(default_value)); - } - /*! - * \brief Check whether the function has an non-zero integer attr. - * - * This function can be used to check whether an optional - * attribute mark(e.g. inline) exists. - * - * \param attr_key The key to the attribute. - * \return The check result. - * - * \code - * - * void HasNonzeroAttrExample(const BaseFunc& f) { - * if (f->HasNonzeroAttr(attr::kInline)) { - * // inline the function. - * } - * } - * - * \endcode - */ - bool HasNonzeroAttr(const std::string& attr_key) const { - return GetAttr(attr_key, 0) != 0; - } - static constexpr const char* _type_key = "BaseFunc"; static constexpr const uint32_t _type_child_slots = 2; TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode); @@ -154,48 +93,6 @@ class BaseFunc : public RelayExpr { TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode); }; -/*! - * \brief Create a new function that copies func, but overrides - * the attribute value key with the value. - * - * \param func The input function. - * \param attr_key The attribute key. - * \param attr_value The value attribute value. - * - * \tparam TFunc The corresponding function type. - * - * \returns The new function with updated attributes. - * - * \note This function performs copy on write optimization for func. - * If we move a uniquely referenced func into WithAttr, - * then no additional copy will be performed. - * - * This is also why we make it as a function instead of a member function - * and why we pass by value in the first argument. - * - * \code - * - * // Recommended way to trigger copy on write - * func = WithAttr(std::move(func), "key1", value1); - * func = WithAttr(std::move(func), "key2", value2); - * - * \endcode - */ -template ::value>::type> -inline TFunc WithAttr(TFunc func, const std::string& attr_key, ObjectRef attr_value) { - using TNode = typename TFunc::ContainerType; - static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); - TNode* node = func.CopyOnWrite(); - if (node->attrs.defined()) { - node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value); - } else { - Map dict = {{attr_key, attr_value}}; - node->attrs = DictAttrs(dict); - } - return func; -} - /*! * \brief Generic attribute names that can be attached to any function. * diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 638f132e3179..c88f438c7cd8 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -58,6 +58,8 @@ class IRModuleNode : public Object { Map type_definitions; /*! \brief The source map for the module. */ parser::SourceMap source_map; + /* \brief Additional attributes storing meta-data about the module. */ + DictAttrs attrs; IRModuleNode() : source_map() {} diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 9c1fe55749e4..614ff939c8ab 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -54,7 +54,7 @@ class TargetNode : public Object { /*! \brief Keys for this target */ Array keys; /*! \brief Collection of attributes */ - Map attrs; + Map attrs; // TODO(@electriclilies): Unify with DictAttrs on IRModule /*! * \brief The raw string representation of the target * \return the full device string to pass to codegen::Build @@ -101,6 +101,7 @@ class TargetNode : public Object { * \param default_value The value returned if the key is not present * \return An optional, NullOpt if not found, otherwise the value found */ + // TODO(@electriclilies): Remove once we have removed the target attrs template Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { return GetAttr(attr_key, Optional(default_value)); diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index d6af9936ca40..9c335d95b156 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -401,7 +401,7 @@ std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target auto host_pass_list = { Filter([](const tir::PrimFunc& f) { - return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != + return f->attrs.GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != CallingConv::kDeviceKernelLaunch; }), BindTarget(target_host), @@ -418,7 +418,7 @@ std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target // device pipeline auto device_pass_list = { Filter([](const tir::PrimFunc& f) { - return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == + return f->attrs.GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDeviceKernelLaunch; }), BindTarget(target), diff --git a/src/ir/module.cc b/src/ir/module.cc index 7990b281fb04..a62656aa69fb 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -356,7 +356,7 @@ IRModule IRModule::FromExpr(const RelayExpr& expr, if (auto* func_node = expr.as()) { func = GetRef(func_node); - if (auto opt = func->GetAttr(tvm::attr::kGlobalSymbol)) { + if (auto opt = func->attrs.GetAttr(tvm::attr::kGlobalSymbol)) { gv_name = opt.value(); } diff --git a/src/relay/analysis/context_analysis.cc b/src/relay/analysis/context_analysis.cc index 35813f67d094..cb785c138912 100644 --- a/src/relay/analysis/context_analysis.cc +++ b/src/relay/analysis/context_analysis.cc @@ -329,7 +329,7 @@ class ContextAnalyzer : public MixedModeVisitor { auto func = GetRef(fn); // No need to step into fused primitive functions as they are handled as // a whole. - if (fn->HasNonzeroAttr(attr::kPrimitive)) { + if (fn->attrs.HasNonzeroAttr(attr::kPrimitive)) { return; } @@ -432,7 +432,7 @@ class ContextAnalyzer : public MixedModeVisitor { } // Check if a function is a closure. - bool IsClosure(const Function& func) { return func->GetAttr(attr::kClosure, 0) != 0; } + bool IsClosure(const Function& func) { return func->attrs.GetAttr(attr::kClosure, 0) != 0; } // Check if a function is a currying function. bool IsCurrying(const Function& func) { diff --git a/src/relay/analysis/extract_fused_functions.cc b/src/relay/analysis/extract_fused_functions.cc index e76b54e2d0b7..9ea2ba9828f1 100644 --- a/src/relay/analysis/extract_fused_functions.cc +++ b/src/relay/analysis/extract_fused_functions.cc @@ -53,7 +53,7 @@ class FusedFunctionExtractorWrapper : private ExprVisitor { Map functions; void VisitExpr_(const FunctionNode* n) final { - if (n->HasNonzeroAttr(attr::kPrimitive)) { + if (n->attrs.HasNonzeroAttr(attr::kPrimitive)) { // Add function to functions, keyed by function hash string Function func = Function(n->params, n->body, n->ret_type, n->type_params, n->attrs); size_t hash_ = tvm::StructuralHash()(func); diff --git a/src/relay/analysis/feature.cc b/src/relay/analysis/feature.cc index f72b4e105749..72964865fd1f 100644 --- a/src/relay/analysis/feature.cc +++ b/src/relay/analysis/feature.cc @@ -60,7 +60,7 @@ FeatureSet DetectFeature(const Expr& expr) { DETECT_DEFAULT_CONSTRUCT(Tuple) DETECT_DEFAULT_CONSTRUCT(TupleGetItem) DETECT_CONSTRUCT(Function, { - if (!op->HasNonzeroAttr(attr::kPrimitive)) { + if (!op->attrs.HasNonzeroAttr(attr::kPrimitive)) { ExprVisitor::VisitExpr_(op); } }) diff --git a/src/relay/analysis/get_calibration_data.cc b/src/relay/analysis/get_calibration_data.cc index 12bab1e38ddd..80460f9c52c5 100644 --- a/src/relay/analysis/get_calibration_data.cc +++ b/src/relay/analysis/get_calibration_data.cc @@ -55,7 +55,7 @@ class Collector : public ExprRewriter { ICHECK(module_->ContainGlobalVar(var->name_hint)) << "Function " << var << " is not defined"; // we only handle functions with Compiler attribute set auto func = Downcast(module_->Lookup(var)); - if (func->GetAttr(attr::kCompiler)) { + if (func->attrs.GetAttr(attr::kCompiler)) { // collect all the inputs and outputs for (const auto& it : call->args) new_outputs_.push_back(it); new_outputs_.push_back(post); @@ -110,7 +110,7 @@ IRModule GetCalibrateModule(IRModule module) { for (const auto& pair : glob_funcs) { if (auto* fn = pair.second.as()) { auto func = GetRef(fn); - if (func->GetAttr(attr::kCompiler)) { + if (func->attrs.GetAttr(attr::kCompiler)) { // we need to inline the functions in order to run grpah runtime func = WithAttr(std::move(func), attr::kInline, tvm::Integer(1)); // reset the compiler attribute to null for llvm execution @@ -145,7 +145,7 @@ class OutputMapper : public ExprRewriter { << "Repeated function call " << var << " is not supported."; auto func = Downcast(module_->Lookup(var)); // we only handle functions with Compiler attribute set - if (func->GetAttr(attr::kCompiler)) { + if (func->attrs.GetAttr(attr::kCompiler)) { Array info; // the first value is the offset info.push_back(Integer(*offset_)); diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 54a10add2f07..a9b381c0461f 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -440,7 +440,7 @@ class AOTExecutorCodegen : public ExprVisitor { void VisitExpr_(const GlobalVarNode* op) override { throw std::runtime_error(""); } void VisitExpr_(const IfNode* op) override { throw std::invalid_argument("if not supported"); } void VisitExpr_(const FunctionNode* op) override { - ICHECK(op->GetAttr(attr::kCompiler).defined()) + ICHECK(op->attrs.GetAttr(attr::kCompiler).defined()) << "FunctionNode only supported by custom codegen"; } void VisitExpr_(const RefCreateNode* op) override { diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 6142e8323dea..beea25efd940 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -89,13 +89,13 @@ class CompileEngineImpl : public CompileEngineNode { auto src_func = it.first->source_func; ICHECK(src_func.defined()); - if (src_func->GetAttr(attr::kCompiler).defined()) { - auto code_gen = src_func->GetAttr(attr::kCompiler); + if (src_func->attrs.GetAttr(attr::kCompiler).defined()) { + auto code_gen = src_func->attrs.GetAttr(attr::kCompiler); ICHECK(code_gen.defined()) << "No external codegen is set"; std::string code_gen_name = code_gen.value(); cached_ext_funcs.push_back(it.first); - auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); + auto symbol_name = src_func->attrs.GetAttr(tvm::attr::kGlobalSymbol); ICHECK(symbol_name.defined()) << "No external symbol is set for:\n" << AsText(src_func, false) << "\n" << "Functions with external codegen must have the " @@ -186,9 +186,9 @@ class CompileEngineImpl : public CompileEngineNode { // No need to lower external functions for now. We will invoke the external // codegen tool once and lower all functions together. - if (key->source_func->GetAttr(attr::kCompiler).defined()) { + if (key->source_func->attrs.GetAttr(attr::kCompiler).defined()) { auto ir_module = IRModule(); - const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); + const auto name_node = key->source_func->attrs.GetAttr(tvm::attr::kGlobalSymbol); ICHECK(name_node.defined()) << "External function has not been attached a name yet."; auto func_name = std::string(name_node.value()); auto target = Target("ext_dev"); diff --git a/src/relay/backend/contrib/arm_compute_lib/codegen.cc b/src/relay/backend/contrib/arm_compute_lib/codegen.cc index 8098c8d51274..0270b01ab8ff 100644 --- a/src/relay/backend/contrib/arm_compute_lib/codegen.cc +++ b/src/relay/backend/contrib/arm_compute_lib/codegen.cc @@ -87,7 +87,7 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer { << cn->op->GetTypeKey(); } auto fn = cn->op.as(); - auto comp = fn->GetAttr(attr::kComposite); + auto comp = fn->attrs.GetAttr(attr::kComposite); ICHECK(comp.defined()) << "Arm Compute Library JSON runtime only supports composite functions."; const std::string name = comp.value(); std::shared_ptr json_node; diff --git a/src/relay/backend/contrib/bnns/codegen.cc b/src/relay/backend/contrib/bnns/codegen.cc index 72c32fb5b19e..6464d015ebc1 100644 --- a/src/relay/backend/contrib/bnns/codegen.cc +++ b/src/relay/backend/contrib/bnns/codegen.cc @@ -80,7 +80,7 @@ class BNNSJSONSerializer : public backend::contrib::JSONSerializer { if (const auto* op_node = cn->op.as()) { name = op_node->name; } else if (const auto* fn = cn->op.as()) { - auto comp = fn->GetAttr(attr::kComposite); + auto comp = fn->attrs.GetAttr(attr::kComposite); ICHECK(comp.defined()) << "BNNS JSON runtime only supports composite functions."; name = comp.value(); @@ -176,7 +176,7 @@ struct BNNSConstantUpdater : public ConstantUpdater { private: bool isBNNSSpecificCompositeFunc(const FunctionNode* op) { - auto comp = op->GetAttr(attr::kComposite); + auto comp = op->attrs.GetAttr(attr::kComposite); if (!comp) return false; auto comp_name = comp.value(); diff --git a/src/relay/backend/contrib/codegen_json/codegen_json.h b/src/relay/backend/contrib/codegen_json/codegen_json.h index 4966f3f01c7d..0366f8d2b838 100644 --- a/src/relay/backend/contrib/codegen_json/codegen_json.h +++ b/src/relay/backend/contrib/codegen_json/codegen_json.h @@ -225,7 +225,7 @@ class JSONSerializer : public MemoizedExprTranslatorattrs.get(); extractor.Extract(const_cast(call_attr)); } else if (const auto* fn = cn->op.as()) { - auto pattern = fn->GetAttr(attr::kPartitionedFromPattern); + auto pattern = fn->attrs.GetAttr(attr::kPartitionedFromPattern); ICHECK(pattern.defined()); std::vector values; values.push_back(pattern.value()); @@ -267,7 +267,7 @@ class JSONSerializer : public MemoizedExprTranslatorop.as()) { name = op_node->name; } else if (const auto* fn = cn->op.as()) { - auto comp = fn->GetAttr(attr::kComposite); + auto comp = fn->attrs.GetAttr(attr::kComposite); ICHECK(comp.defined()) << "JSON runtime only supports composite functions."; name = comp.value(); } else { @@ -298,7 +298,7 @@ class JSONSerializer : public MemoizedExprTranslator VisitExpr_(const FunctionNode* fn) { - ICHECK(fn->GetAttr(attr::kComposite).defined()) + ICHECK(fn->attrs.GetAttr(attr::kComposite).defined()) << "JSON runtime only supports composite functions"; // FunctionNode should be handled by the caller. return {}; diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index e96255e976e9..cfde6550a431 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -440,7 +440,7 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer { if (const auto* op_node = cn->op.as()) { name = op_node->name; } else if (const auto* fn = cn->op.as()) { - auto comp = fn->GetAttr(attr::kComposite); + auto comp = fn->attrs.GetAttr(attr::kComposite); ICHECK(comp.defined()) << "DNNL JSON runtime only supports composite functions."; name = comp.value(); diff --git a/src/relay/backend/contrib/ethosn/codegen.cc b/src/relay/backend/contrib/ethosn/codegen.cc index 97b308e51e18..f3343c5e2648 100644 --- a/src/relay/backend/contrib/ethosn/codegen.cc +++ b/src/relay/backend/contrib/ethosn/codegen.cc @@ -53,7 +53,7 @@ bool IsEthosnFunc(const Call& call, const std::string& op_name) { if (call->op->IsInstance()) { Function func = Downcast(call->op); ICHECK(func.defined()); - auto name_node = func->GetAttr(attr::kComposite); + auto name_node = func->attrs.GetAttr(attr::kComposite); return name_node.value() == op_name; } return false; @@ -521,7 +521,7 @@ runtime::Module EthosnCompiler::CreateRuntimeModule(const ObjectRef& ref) { if (ref->IsInstance()) { IRModule mod; Function func = Downcast(ref); - auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); + auto name_node = func->attrs.GetAttr(tvm::attr::kGlobalSymbol); ICHECK(name_node.defined()) << "Failed to retrieved external symbol."; GlobalVar gvar = GlobalVar(name_node.value()); mod->Add(gvar, func); diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index cc54a52be200..bbd80eccce39 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -226,7 +226,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorGetAttr(attr::kCompiler).defined()) { + if (func->attrs.GetAttr(attr::kCompiler).defined()) { UpdateConstants(func, ¶ms_); } @@ -473,7 +473,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator VisitExpr_(const FunctionNode* op) override { - ICHECK(op->GetAttr(attr::kCompiler).defined()) + ICHECK(op->attrs.GetAttr(attr::kCompiler).defined()) << "Only functions supported by custom codegen"; return {}; } diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 93c823d8a007..a4fad4bfe4f5 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -345,7 +345,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { */ static bool IsReshape(const CallNode* call) { if (const auto* fn = call->op.as()) { - return fn->HasNonzeroAttr(attr::kReshapeOnly); + return fn->attrs.HasNonzeroAttr(attr::kReshapeOnly); } if (call->attrs.defined()) { diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 6ebb17e93eca..3970d6c041d9 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -480,7 +480,7 @@ class Interpreter : public ExprFunctor, bool is_dyn = IsDynamic(ret_type); if (is_dyn) { - ICHECK(func->HasNonzeroAttr(attr::kPrimitive)); + ICHECK(func->attrs.HasNonzeroAttr(attr::kPrimitive)); out_shapes = ComputeDynamicShape(func, args); } @@ -519,7 +519,7 @@ class Interpreter : public ExprFunctor, ObjectRef Invoke(const InterpreterClosure& closure, const tvm::Array& args, const Var& bind = Var()) { // Get a reference to the function inside the closure. - if (closure->func->HasNonzeroAttr(attr::kPrimitive)) { + if (closure->func->attrs.HasNonzeroAttr(attr::kPrimitive)) { return InvokePrimitiveOp(closure->func, args); } auto func = closure->func; diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 7840960ec268..07fbc6d68265 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -111,12 +111,12 @@ class TECompilerImpl : public TECompilerNode { for (const auto& it : cache_) { auto src_func = it.first->source_func; ICHECK(src_func.defined()); - if (src_func->GetAttr(attr::kCompiler).defined()) { - auto code_gen = src_func->GetAttr(attr::kCompiler); + if (src_func->attrs.GetAttr(attr::kCompiler).defined()) { + auto code_gen = src_func->attrs.GetAttr(attr::kCompiler); std::string code_gen_name = code_gen.value(); cached_ext_funcs.push_back(it.first); - auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); + auto symbol_name = src_func->attrs.GetAttr(tvm::attr::kGlobalSymbol); ICHECK(symbol_name.defined()) << "No external symbol is set for:\n" << AsText(src_func, false); @@ -187,9 +187,9 @@ class TECompilerImpl : public TECompilerNode { // No need to lower external functions for now. We will invoke the external // codegen tool once and lower all functions together. - if (key->source_func->GetAttr(attr::kCompiler).defined()) { + if (key->source_func->attrs.GetAttr(attr::kCompiler).defined()) { auto ir_module = IRModule(); - const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); + const auto name_node = key->source_func->attrs.GetAttr(tvm::attr::kGlobalSymbol); ICHECK(name_node.defined()) << "External function has not been attached a name yet."; auto func_name = GetUniqueName(name_node.value(), &name_map_); auto target = Target("ext_dev"); @@ -325,7 +325,7 @@ class LowerTensorExpr : public ExprMutator { return ExprMutator::VisitExpr_(call); } - if (!func->HasNonzeroAttr(attr::kPrimitive)) { + if (!func->attrs.HasNonzeroAttr(attr::kPrimitive)) { // Provide a callback hook which allows one-level up code generators to // act when we process a function. this->process_fn(func); @@ -340,7 +340,7 @@ class LowerTensorExpr : public ExprMutator { Target target; - if (func->GetAttr(attr::kCompiler).defined()) { + if (func->attrs.GetAttr(attr::kCompiler).defined()) { target = Target("ext_dev"); CCacheKey key = CCacheKey(func, target); CachedFunc ext_func = compiler_->Lower(key, module_name_); @@ -390,7 +390,7 @@ class LowerTensorExpr : public ExprMutator { this->process_fn(func_with_metadata); auto tir_call_attrs = make_object(); - if (func->HasNonzeroAttr(attr::kReshapeOnly)) { + if (func->attrs.HasNonzeroAttr(attr::kReshapeOnly)) { tir_call_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1)); } @@ -604,13 +604,13 @@ void UpdateFunctionMetadata(Function relay_func, Map relay_primfuncs; Optional> prim_fns = - relay_func->GetAttr>("prim_funcs"); + relay_func->attrs.GetAttr>("prim_funcs"); CHECK(prim_fns) << "primitive functions not set on Relay function by TECompiler."; - Optional prim_fn_var = relay_func->GetAttr("prim_fn_var"); + Optional prim_fn_var = relay_func->attrs.GetAttr("prim_fn_var"); CHECK(prim_fn_var) << "prim_fn_var must be set on Relay functions by TECompiler."; - Optional relay_target = relay_func->GetAttr("target"); + Optional relay_target = relay_func->attrs.GetAttr("target"); CHECK(relay_target) << "target must be set on Relay functions by the TECompiler."; for (const auto& kv : prim_fns.value()) { diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index a0c7a5aad26d..615d212e0169 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -173,10 +173,10 @@ struct ConstantUpdater : public ExprVisitor { */ inline void UpdateConstants(Function func, std::unordered_map* params) { - auto codegen = func->GetAttr(attr::kCompiler); + auto codegen = func->attrs.GetAttr(attr::kCompiler); ICHECK(codegen.defined()) << "No external codegen is set"; std::string codegen_name = codegen.value(); - const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); + const auto name_node = func->attrs.GetAttr(tvm::attr::kGlobalSymbol); std::string symbol = std::string(name_node.value()); std::string const_update_name = "relay.ext." + codegen_name + ".constant_updater"; // Get the constant updater for the external codegen @@ -392,7 +392,7 @@ inline const CallNode* GetRootCall(const CallNode* current_call, int depth, * \return An external symbol. */ inline std::string GetExtSymbol(const Function& func) { - const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); + const auto name_node = func->attrs.GetAttr(tvm::attr::kGlobalSymbol); ICHECK(name_node.defined()) << "Fail to retrieve external symbol."; return std::string(name_node.value()); } diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b3eab91d202c..62c45ff0168f 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -501,7 +501,7 @@ class VMFunctionCompiler : ExprFunctor { void EmitInvokeTVMOp(const Function& func, const Expr& inputs, const Expr& outputs) { std::vector argument_registers; - ICHECK(func->GetAttr(attr::kPrimitive, 0) != 0) + ICHECK(func->attrs.GetAttr(attr::kPrimitive, 0) != 0) << "internal error: invoke_tvm_op requires the first argument to be a relay::Function"; auto input_tuple = inputs.as(); @@ -526,7 +526,7 @@ class VMFunctionCompiler : ExprFunctor { Target target; - if (func->GetAttr(attr::kCompiler).defined()) { + if (func->attrs.GetAttr(attr::kCompiler).defined()) { target = Target("ext_dev"); } else { // Next generate the invoke instruction. @@ -553,7 +553,7 @@ class VMFunctionCompiler : ExprFunctor { auto cfunc = context_->compiler->Lower(key, mangle_fn); auto op_index = -1; - if (func->GetAttr(attr::kCompiler).defined()) { + if (func->attrs.GetAttr(attr::kCompiler).defined()) { op_index = context_->cached_funcs.size(); context_->cached_funcs.push_back(cfunc); } else { @@ -765,7 +765,7 @@ class VMFunctionCompiler : ExprFunctor { } void VisitExpr_(const FunctionNode* func_node) { - if (!func_node->HasNonzeroAttr(attr::kPrimitive)) { + if (!func_node->attrs.HasNonzeroAttr(attr::kPrimitive)) { LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl << "Program: " << AsText(GetRef(func_node), false) << std::endl << "AST: " << GetRef(func_node); diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index 05fb2a120620..48976384c50a 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -97,7 +97,7 @@ struct PrimitiveInliner : ExprMutator { } if (auto func = op.as()) { - if (func->HasNonzeroAttr(attr::kPrimitive)) { + if (func->attrs.HasNonzeroAttr(attr::kPrimitive)) { tvm::Array call_args; for (auto arg : call->args) { auto new_arg = VisitExpr(arg); @@ -120,7 +120,7 @@ struct PrimitiveInliner : ExprMutator { } Expr VisitExpr_(const FunctionNode* func) { - if (func->HasNonzeroAttr(attr::kPrimitive)) { + if (func->attrs.HasNonzeroAttr(attr::kPrimitive)) { return GetRef(func); } else { return ExprMutator::VisitExpr_(func); @@ -133,7 +133,7 @@ struct PrimitiveInliner : ExprMutator { auto global = pair.first; auto base_func = pair.second; if (auto* n = base_func.as()) { - if (n->GetAttr(attr::kCompiler).defined()) continue; + if (n->attrs.GetAttr(attr::kCompiler).defined()) continue; auto func = GetRef(n); DLOG(INFO) << "Before inlining primitives: " << global << std::endl << AsText(func, false); diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index c768a2c300ec..bce1f4437b08 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -44,7 +44,7 @@ inline std::string GenerateName(const Function& func) { return std::string("lifted_name") + std::to_string(hash); } -bool IsClosure(const Function& func) { return func->GetAttr(attr::kClosure, 0) != 0; } +bool IsClosure(const Function& func) { return func->attrs.GetAttr(attr::kClosure, 0) != 0; } Function MarkClosure(Function func) { return WithAttr(std::move(func), attr::kClosure, tvm::Integer(1)); @@ -64,7 +64,7 @@ class LambdaLifter : public ExprMutator { auto pre_visit = [this](const LetNode* op) { bool is_lambda = false; if (auto func = op->value.as()) { - if (!func->HasNonzeroAttr(attr::kPrimitive)) { + if (!func->attrs.HasNonzeroAttr(attr::kPrimitive)) { is_lambda = true; this->letrec_.push_back(op->var); } @@ -104,7 +104,7 @@ class LambdaLifter : public ExprMutator { auto func = GetRef(func_node); // We should not transform primitive functions. - if (func->HasNonzeroAttr(attr::kPrimitive)) { + if (func->attrs.HasNonzeroAttr(attr::kPrimitive)) { return std::move(func); } @@ -224,7 +224,7 @@ class LambdaLifter : public ExprMutator { auto glob_funcs = module_->functions; for (auto pair : glob_funcs) { if (auto* n = pair.second.as()) { - if (n->GetAttr(attr::kCompiler).defined()) continue; + if (n->attrs.GetAttr(attr::kCompiler).defined()) continue; auto func = GetRef(n); func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, func->attrs); diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index 4a7974cae5ae..a7f0959e169d 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -163,8 +163,8 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) } bool FunctionPassNode::SkipFunction(const Function& func) const { - return (func->GetAttr(attr::kCompiler).defined()) || - func->GetAttr(attr::kSkipOptimization, 0) != 0; + return (func->attrs.GetAttr(attr::kCompiler).defined()) || + func->attrs.GetAttr(attr::kSkipOptimization, 0) != 0; } Pass CreateFunctionPass( diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index b12e25a425b6..19c96a161483 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -227,7 +227,7 @@ class AnnotateTargetRewriter : public ExprRewriter { // if it is in the target list. Function func = Downcast(pre->op); ICHECK(func.defined()); - if (auto comp_name = func->GetAttr(attr::kComposite)) { + if (auto comp_name = func->attrs.GetAttr(attr::kComposite)) { std::string comp_name_str = comp_name.value(); size_t i = comp_name_str.find('.'); if (i != std::string::npos) { @@ -288,7 +288,7 @@ class AnnotateTargetRewriter : public ExprRewriter { Function func; Expr new_body; // don't step into composite functions - if (fn->GetAttr(attr::kComposite).defined()) { + if (fn->attrs.GetAttr(attr::kComposite).defined()) { func = GetRef(fn); new_body = func->body; } else { diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 57603035b848..2a3def468ab8 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -118,7 +118,7 @@ class ConstantFolder : public MixedModeMutator { bool inside_primitive = false; Expr VisitExpr_(const FunctionNode* op) final { - if (op->HasNonzeroAttr(attr::kPrimitive)) { + if (op->attrs.HasNonzeroAttr(attr::kPrimitive)) { ICHECK_EQ(inside_primitive, false); inside_primitive = true; auto ret = ExprMutator::VisitExpr_(op); diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index f1f7a95e33e8..1542e6d68d86 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -199,7 +199,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { // Post order tree void VisitExpr_(const FunctionNode* op) final { // Skip the function that should be handled by external codegen. - if (op->GetAttr(attr::kCompiler).defined()) return; + if (op->attrs.GetAttr(attr::kCompiler).defined()) return; for (auto param : op->params) { this->Update(param, nullptr, kOpaque); @@ -856,7 +856,7 @@ class FuseMutator : private MixedModeMutator { // Skip primitive function. Expr VisitExpr_(const FunctionNode* fn_node) { - if (fn_node->HasNonzeroAttr(attr::kPrimitive)) { + if (fn_node->attrs.HasNonzeroAttr(attr::kPrimitive)) { return GetRef(fn_node); } else { return ExprMutator::VisitExpr_(fn_node); diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc index 6e6505b28dc6..a8ce7bd22bbf 100644 --- a/src/relay/transforms/inline.cc +++ b/src/relay/transforms/inline.cc @@ -99,7 +99,7 @@ class Inliner : ExprMutator { if (!func->body.defined()) return false; // The function must be annotated with the inline attribute. - if (!func->HasNonzeroAttr(attr::kInline)) return false; + if (!func->attrs.HasNonzeroAttr(attr::kInline)) return false; // The function is not abled to be inlined if any callee under the CallGraph // of this function cannot be inlined. @@ -122,7 +122,7 @@ class Inliner : ExprMutator { auto func = Function(fn->params, fn->body, fn->ret_type, fn->type_params, fn->attrs); // Inline the function body to the caller if this function uses default // compiler, i.e. no external codegen is needed. - if (!func->GetAttr(attr::kCompiler).defined()) { + if (!func->attrs.GetAttr(attr::kCompiler).defined()) { ICHECK_EQ(func->params.size(), args.size()) << "Mismatch found in the number of parameters and call args"; // Bind the parameters with call args. @@ -188,7 +188,7 @@ IRModule Inline(const IRModule& module) { auto base_func = cg->GetGlobalFunction(cgn->GetGlobalVar()); if (const auto* fn = base_func.as()) { auto func = GetRef(fn); - if (func->HasNonzeroAttr(attr::kInline)) { + if (func->attrs.HasNonzeroAttr(attr::kInline)) { ICHECK_EQ(cgn->GetRefCount(), 0U) << cgn->GetNameHint() << " is marked as inline but not inlined."; cgn->CleanCallGraphEntries(); diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index 657e2c392455..7717c741951a 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -71,7 +71,7 @@ inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dt // Check if the primitive function contains only reshape ops. bool IsReshapeOnly(const Expr& expr) { if (const FunctionNode* func = expr.as()) { - return func->HasNonzeroAttr(attr::kReshapeOnly); + return func->attrs.HasNonzeroAttr(attr::kReshapeOnly); } if (const CallNode* call = expr.as()) { if (call->attrs.defined()) { @@ -199,7 +199,7 @@ class DialectRewriter : public ExprMutator { // Check if a call invokes a primitive function. bool IsPrimitive(const CallNode* call) const { if (const auto* fn = call->op.as()) { - return fn->HasNonzeroAttr(attr::kPrimitive); + return fn->attrs.HasNonzeroAttr(attr::kPrimitive); } return false; } diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index 9572faf08714..fb49d69855f0 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -776,7 +776,7 @@ class PartialEvaluator : public ExprFunctor Func VisitFuncStatic(const Function& func, const Expr& var) { ICHECK(IsAtomic(var)); - if (func->HasNonzeroAttr(attr::kPrimitive)) { + if (func->attrs.HasNonzeroAttr(attr::kPrimitive)) { return ConstEvaluateFunc(func); } std::vector > free_vars; diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index b48fbe44bd11..d67da140a803 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -499,7 +499,7 @@ class NameMangleExtFuncs : public MixedModeMutator { for (const auto& pair : glob_funcs) { if (auto* fn = pair.second.as()) { auto func = GetRef(fn); - if (func->GetAttr(attr::kCompiler).defined()) { + if (func->attrs.GetAttr(attr::kCompiler).defined()) { auto fn_name_mangled = mangle_fn_(pair.first->name_hint); GlobalVar gvar = GlobalVar(fn_name_mangled); mangled_gvars_[pair.first->name_hint] = gvar; @@ -514,7 +514,7 @@ class NameMangleExtFuncs : public MixedModeMutator { if (auto* fn = pair.second.as()) { auto func = GetRef(fn); - if (func->GetAttr(attr::kCompiler).defined()) { + if (func->attrs.GetAttr(attr::kCompiler).defined()) { auto new_dict = func->attrs->dict; new_dict.Set(tvm::attr::kGlobalSymbol, String(mangle_fn_(pair.first->name_hint))); func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 91e8d90c1232..189165131b8e 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -200,7 +200,7 @@ Expr Fill::VisitExpr_(const IfNode* i, const Var& v) { Expr Fill::VisitExpr_(const FunctionNode* f, const Var& v) { Expr e = GetRef(f); Expr ret; - if (f->HasNonzeroAttr(attr::kPrimitive)) { + if (f->attrs.HasNonzeroAttr(attr::kPrimitive)) { ret = e; } else { ret = Function(f->params, GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body)), f->ret_type, @@ -260,7 +260,7 @@ IRModule ToANormalForm(const IRModule& m) { for (const auto& it : funcs) { ICHECK_EQ(FreeVars(it.second).size(), 0); if (const auto* n = it.second.as()) { - if (n->GetAttr(attr::kCompiler).defined()) continue; + if (n->attrs.GetAttr(attr::kCompiler).defined()) continue; } Expr ret = TransformF([&](const Expr& e) { return transform::ToANormalForm(e); }, it.second); ICHECK_EQ(FreeVars(ret).size(), 0) diff --git a/src/relay/transforms/to_basic_block_normal_form.cc b/src/relay/transforms/to_basic_block_normal_form.cc index 79157bba1918..fbab6a91b824 100644 --- a/src/relay/transforms/to_basic_block_normal_form.cc +++ b/src/relay/transforms/to_basic_block_normal_form.cc @@ -56,7 +56,7 @@ IRModule ToBasicBlockNormalForm(const IRModule& mod) { for (const auto& it : funcs) { ICHECK_EQ(FreeVars(it.second).size(), 0) << "Expected no free variables"; if (const auto* n = it.second.as()) { - if (n->GetAttr(attr::kCompiler).defined()) continue; + if (n->attrs.GetAttr(attr::kCompiler).defined()) continue; } Expr ret = TransformF([&](const Expr& e) { return ToBasicBlockNormalFormAux(e); }, it.second); updates.Set(it.first, Downcast(ret)); diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc index b7f9cafbc7dc..5dc0198d9872 100644 --- a/src/relay/transforms/to_cps.cc +++ b/src/relay/transforms/to_cps.cc @@ -134,7 +134,7 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm, } Expr VisitExpr_(const FunctionNode* op, const MCont& k) final { - ICHECK(!op->HasNonzeroAttr(attr::kPrimitive)) << "primitive func not supported yet."; + ICHECK(!op->attrs.HasNonzeroAttr(attr::kPrimitive)) << "primitive func not supported yet."; return k(ToCPS(GetRef(op), m, cm, vm, answer)); } diff --git a/src/target/build_common.h b/src/target/build_common.h index c66c2b52822e..313ab5fa72bb 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -50,18 +50,18 @@ inline std::unordered_map ExtractFuncInfo(co for (size_t i = 0; i < f->params.size(); ++i) { info.arg_types.push_back(f->params[i].dtype()); } - if (auto opt = f->GetAttr>(tir::attr::kDeviceThreadAxis)) { + if (auto opt = f->attrs.GetAttr>(tir::attr::kDeviceThreadAxis)) { auto thread_axis = opt.value(); for (size_t i = 0; i < thread_axis.size(); ++i) { info.launch_param_tags.push_back(thread_axis[i]->thread_tag); } } - if (auto opt = f->GetAttr(tir::attr::kDeviceUseDynSharedMemory)) { + if (auto opt = f->attrs.GetAttr(tir::attr::kDeviceUseDynSharedMemory)) { if (opt.value()) { info.launch_param_tags.push_back(runtime::kUseDynamicSharedMemoryTag); } } - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->attrs.GetAttr(tvm::attr::kGlobalSymbol); fmap[static_cast(global_symbol.value())] = info; } return fmap; diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index ab96d6e69d14..3e7c622a2f67 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -118,7 +118,7 @@ void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, void CodeGenCPU::AddFunction(const PrimFunc& f) { CodeGenLLVM::AddFunction(f); if (f_tvm_register_system_symbol_ != nullptr) { - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->attrs.GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; export_system_symbols_.emplace_back( diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index 9d324d56887f..f5ab55421ce9 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -47,7 +47,7 @@ namespace tvm { namespace codegen { static std::string get_name(const PrimFunc& f) { - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->attrs.GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; return std::string(global_symbol.value()); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 6aabdc1bd804..f87c5e65fdbd 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -116,7 +116,7 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { << "Cannot codegen function with buffer_map, please lower them first"; std::vector param_types; - is_restricted_ = f->HasNonzeroAttr(tir::attr::kNoAlias); + is_restricted_ = f->attrs.HasNonzeroAttr(tir::attr::kNoAlias); for (Var param : f->params) { param_types.push_back(GetLLVMType(param)); if (!is_restricted_ && param.dtype().is_handle()) { @@ -129,7 +129,7 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { llvm::FunctionType* ftype = llvm::FunctionType::get(ret_void ? t_void_ : t_int_, param_types, false); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->attrs.GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; ICHECK(module_->getFunction(static_cast(global_symbol.value())) == nullptr) diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 12c7a3132947..defccc7cd4ac 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -234,10 +234,10 @@ class LLVMModuleNode final : public runtime::ModuleNode { continue; } auto f = Downcast(kv.second); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->attrs.GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()); function_names_.push_back(global_symbol.value()); - if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + if (f->attrs.HasNonzeroAttr(tir::attr::kIsEntryFunc)) { entry_func = global_symbol.value(); } funcs.push_back(f); diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index 4a2917daa5ed..63fe53b52092 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -134,7 +134,7 @@ runtime::Module BuildCUDA(IRModule mod, Target target) { for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenCUDA: Can only take PrimFunc"; auto f = Downcast(kv.second); - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + auto calling_conv = f->attrs.GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); diff --git a/src/target/source/codegen_aocl.cc b/src/target/source/codegen_aocl.cc index 17e38e9af6e6..e2125ccc9a52 100644 --- a/src/target/source/codegen_aocl.cc +++ b/src/target/source/codegen_aocl.cc @@ -43,7 +43,7 @@ runtime::Module BuildAOCL(IRModule mod, Target target, bool emulation) { for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodegenOpenCL: Can only take PrimFunc"; auto f = Downcast(kv.second); - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + auto calling_conv = f->attrs.GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodegenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index f676f0f598d8..4691f439b864 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -77,10 +77,10 @@ void CodeGenC::AddFunction(const PrimFunc& f) { // reserve keywords ReserveKeywordsAsUnique(); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->attrs.GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; - bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); + bool no_alias = f->attrs.HasNonzeroAttr(tir::attr::kNoAlias); this->PrintFuncPrefix(); this->PrintExtraAttrs(f); diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index dc849b8fa6b3..f40c716f4630 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -53,7 +53,7 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, std::string target_s void CodeGenCHost::DefineModuleName() { decl_stream << "void* " << module_name_ << " = NULL;\n"; } void CodeGenCHost::AddFunction(const PrimFunc& f) { - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->attrs.GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute"; function_names_.push_back(global_symbol.value()); @@ -390,7 +390,7 @@ runtime::Module BuildCHost(IRModule mod, Target target) { // Make sure that the executor function is the last one to be code generated so that all the // symbols are available to tvm_run_func auto fun_name = std::string(kv.first->name_hint); - bool is_aot_executor_fn = kv.second->GetAttr("runner_function", Bool(false)).value(); + bool is_aot_executor_fn = kv.second->attrs.GetAttr("runner_function", Bool(false)).value(); if (is_aot_executor_fn) { aot_executor_fn = Downcast(kv.second); diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index b44afec57d5d..25e7a8dab72a 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -58,7 +58,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { GetUniqueName("_"); // add to alloc buffer type. - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->attrs.GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; @@ -130,7 +130,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { ICHECK_EQ(GetUniqueName("threadIdx"), "threadIdx"); ICHECK_EQ(GetUniqueName("blockIdx"), "blockIdx"); int work_dim = 0; - auto thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis).value(); + auto thread_axis = f->attrs.GetAttr>(tir::attr::kDeviceThreadAxis).value(); for (IterVar iv : thread_axis) { runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag); @@ -340,7 +340,7 @@ runtime::Module BuildMetal(IRModule mod, Target target) { CodeGenMetal cg(target); cg.Init(output_ssa); auto f = Downcast(kv.second); - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + auto calling_conv = f->attrs.GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index edb614d9c122..eb80d32017fc 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -292,7 +292,7 @@ runtime::Module BuildOpenCL(IRModule mod, Target target) { CodeGenOpenCL cg; cg.Init(output_ssa); auto f = Downcast(kv.second); - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + auto calling_conv = f->attrs.GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc index 9896d8b833f9..8c2b6b25ed0b 100644 --- a/src/target/source/codegen_vhls.cc +++ b/src/target/source/codegen_vhls.cc @@ -148,7 +148,7 @@ runtime::Module BuildSDAccel(IRModule mod, Target target) { for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenVHLS: Can only take PrimFunc"; auto f = Downcast(kv.second); - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + auto calling_conv = f->attrs.GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenVLHS: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); @@ -170,7 +170,7 @@ runtime::Module BuildSDAccel(IRModule mod, Target target) { code = (*f)(code).operator std::string(); } - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->attrs.GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; kernel_info.push_back({global_symbol.value(), code}); diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index e922942e8acf..561103305bc4 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -113,10 +113,10 @@ runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction) for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenSPIRV: Can only take PrimFunc"; auto f = Downcast(kv.second); - auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); + auto calling_conv = f->attrs.GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->attrs.GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 42d0027a326f..275bb1364456 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -41,7 +41,7 @@ CodeGenSPIRV::CodeGenSPIRV(Target target) : spirv_support_(target) {} runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) { this->InitFuncState(); - ICHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model"; + ICHECK(f->attrs.HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model"; std::vector pod_args; uint32_t i_buffer = 0; diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index 402e3291975f..c40e15bee92f 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -516,7 +516,7 @@ runtime::Module BuildStackVM(IRModule mod, Target target) { for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenStackVM: Can only take PrimFunc"; auto f = Downcast(kv.second); - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->attrs.GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute"; std::string f_name = global_symbol.value(); @@ -524,7 +524,7 @@ runtime::Module BuildStackVM(IRModule mod, Target target) { ICHECK(!fmap.count(f_name)) << "Function name " << f_name << "already exist in list"; fmap[f_name] = std::move(vm); - if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + if (f->attrs.HasNonzeroAttr(tir::attr::kIsEntryFunc)) { entry_func = f_name; } } diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 2089ead98168..14e5e7f9973a 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -169,10 +169,10 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { /// Interface of VerifyMemory pass std::vector VerifyMemory_(const PrimFunc& func) { - auto target = func->GetAttr(tvm::attr::kTarget); + auto target = func->attrs.GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "VerifyMemory: Require the target attribute"; - if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == + if (func->attrs.GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDefault) { MemoryAccessVerifier v(func, target.value()->kind->device_type); v.Run(); diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index a41905c148bf..5603247f6646 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -245,7 +245,7 @@ Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region } Bool IsFromLegacyTESchedule(PrimFunc f) { - Optional from_legacy_te_schedule = f->GetAttr("from_legacy_te_schedule", Bool(false)); + Optional from_legacy_te_schedule = f->attrs.GetAttr("from_legacy_te_schedule", Bool(false)); return from_legacy_te_schedule.value(); } diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index 21f1b18d523b..12b231c3bf2c 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -207,7 +207,7 @@ namespace transform { Pass LowerCustomDatatypes() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - auto target = f->GetAttr(tvm::attr::kTarget); + auto target = f->attrs.GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "LowerCustomDatatypes: Require the target attribute"; n->body = CustomDatatypesLowerer(target.value()->kind->name)(std::move(n->body)); diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 2555002d29b0..6414fc1819b8 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -297,9 +297,10 @@ namespace transform { Pass LowerIntrin() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - auto target = f->GetAttr(tvm::attr::kTarget); + auto target = f->attrs.GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "LowerIntrin: Require the target attribute"; arith::Analyzer analyzer; + // TODO(@electriclilies): This is most likely a problem.. auto mtriple = target.value()->GetAttr("mtriple", ""); n->body = IntrinInjecter(&analyzer, target.value()->kind->name, mtriple.value())(std::move(n->body)); diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 481b1bfd4b19..572ce88ba012 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -600,7 +600,7 @@ namespace transform { Pass LowerThreadAllreduce() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - auto target = f->GetAttr(tvm::attr::kTarget); + auto target = f->attrs.GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute"; const TargetNode* target_node = target.as(); ThreadAllreduceBuilder thread_all_reduce(target_node); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 30ec148c37dd..37ded00bfc0f 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -387,7 +387,7 @@ namespace transform { Pass LowerWarpMemory() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - auto target = f->GetAttr(tvm::attr::kTarget); + auto target = f->attrs.GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; int warp_size = target.value()->GetAttr("thread_warp_size", 1).value(); WarpMemoryRewriter warp_memory_rewriter(warp_size); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 393ce6c286b4..2f0c389c0373 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -106,10 +106,10 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { } PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { - auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = func->attrs.GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute"; - auto target = func->GetAttr(tvm::attr::kTarget); + auto target = func->attrs.GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "MakePackedAPI: Require the target attribute"; int target_device_type = target.value()->kind->device_type; @@ -294,7 +294,7 @@ Pass MakePackedAPI(int num_unpacked_args) { for (const auto& kv : mptr->functions) { if (auto* n = kv.second.as()) { PrimFunc func = GetRef(n); - if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == + if (func->attrs.GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDefault) { auto updated_func = MakePackedAPI(std::move(func), num_unpacked_args); updates.push_back({kv.first, updated_func}); diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index 6e8793fbd367..d08f3320accb 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -41,10 +41,10 @@ namespace tvm { namespace tir { PrimFunc MakeUnpackedAPI(PrimFunc&& func) { - auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = func->attrs.GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol) << "MakeUnpackedAPI: Expect PrimFunc to have the global_symbol attribute"; - auto target = func->GetAttr(tvm::attr::kTarget); + auto target = func->attrs.GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "MakeUnpackedAPI: Require the target attribute"; auto* func_ptr = func.CopyOnWrite(); @@ -111,7 +111,7 @@ Pass MakeUnpackedAPI() { for (const auto& kv : mptr->functions) { if (auto* n = kv.second.as()) { PrimFunc func = GetRef(n); - if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == + if (func->attrs.GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDefault) { auto updated_func = MakeUnpackedAPI(std::move(func)); updates.push_back({kv.first, updated_func}); diff --git a/src/tir/transforms/remap_thread_axis.cc b/src/tir/transforms/remap_thread_axis.cc index e101e6b904ce..61867d18923e 100644 --- a/src/tir/transforms/remap_thread_axis.cc +++ b/src/tir/transforms/remap_thread_axis.cc @@ -75,7 +75,7 @@ PrimFunc RemapThreadAxis(PrimFunc&& f, Map thread_map) tmap[kv.first] = kv.second; } - auto opt_thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis); + auto opt_thread_axis = f->attrs.GetAttr>(tir::attr::kDeviceThreadAxis); ICHECK(opt_thread_axis != nullptr) << "Require attribute " << tir::attr::kDeviceThreadAxis; auto thread_axis = opt_thread_axis.value(); auto* n = f.CopyOnWrite(); diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 795ae9d6a73a..05be4c06ffaa 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -311,9 +311,9 @@ class HostDeviceSplitter : public StmtMutator { }; PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) { - auto target = func->GetAttr(tvm::attr::kTarget); + auto target = func->attrs.GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "SplitHostDevice: Require the target attribute"; - auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = func->attrs.GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute"; diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 2c32cc7f0883..98c9d5b4f5fa 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -501,7 +501,7 @@ class StorageFlattener : public StmtExprMutator { PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_attributes) { // Only apply this pass to TIR from TE schedules - Optional from_legacy_te_schedule = func->GetAttr("from_legacy_te_schedule", Bool(false)); + Optional from_legacy_te_schedule = func->attrs.GetAttr("from_legacy_te_schedule", Bool(false)); if (from_legacy_te_schedule.value()) { auto fptr = func.CopyOnWrite(); From a3ae6cb5232ca3a1712b37dfa736cd2e023b1db5 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 13 Aug 2021 18:12:15 -0700 Subject: [PATCH 2/9] lint --- include/tvm/target/target.h | 2 +- src/relay/analysis/context_analysis.cc | 4 +++- src/relay/backend/vm/lambda_lift.cc | 4 +++- src/target/source/codegen_c_host.cc | 3 ++- src/target/spirv/codegen_spirv.cc | 3 ++- src/tir/transforms/storage_flatten.cc | 3 ++- 6 files changed, 13 insertions(+), 6 deletions(-) diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 614ff939c8ab..362737852dc9 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -54,7 +54,7 @@ class TargetNode : public Object { /*! \brief Keys for this target */ Array keys; /*! \brief Collection of attributes */ - Map attrs; // TODO(@electriclilies): Unify with DictAttrs on IRModule + Map attrs; // TODO(@electriclilies): Unify with DictAttrs on IRModule /*! * \brief The raw string representation of the target * \return the full device string to pass to codegen::Build diff --git a/src/relay/analysis/context_analysis.cc b/src/relay/analysis/context_analysis.cc index cb785c138912..66bcfb098fce 100644 --- a/src/relay/analysis/context_analysis.cc +++ b/src/relay/analysis/context_analysis.cc @@ -432,7 +432,9 @@ class ContextAnalyzer : public MixedModeVisitor { } // Check if a function is a closure. - bool IsClosure(const Function& func) { return func->attrs.GetAttr(attr::kClosure, 0) != 0; } + bool IsClosure(const Function& func) { + return func->attrs.GetAttr(attr::kClosure, 0) != 0; + } // Check if a function is a currying function. bool IsCurrying(const Function& func) { diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index bce1f4437b08..e5882159e1f0 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -44,7 +44,9 @@ inline std::string GenerateName(const Function& func) { return std::string("lifted_name") + std::to_string(hash); } -bool IsClosure(const Function& func) { return func->attrs.GetAttr(attr::kClosure, 0) != 0; } +bool IsClosure(const Function& func) { + return func->attrs.GetAttr(attr::kClosure, 0) != 0; +} Function MarkClosure(Function func) { return WithAttr(std::move(func), attr::kClosure, tvm::Integer(1)); diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index f40c716f4630..eaed10c99a9c 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -390,7 +390,8 @@ runtime::Module BuildCHost(IRModule mod, Target target) { // Make sure that the executor function is the last one to be code generated so that all the // symbols are available to tvm_run_func auto fun_name = std::string(kv.first->name_hint); - bool is_aot_executor_fn = kv.second->attrs.GetAttr("runner_function", Bool(false)).value(); + bool is_aot_executor_fn = + kv.second->attrs.GetAttr("runner_function", Bool(false)).value(); if (is_aot_executor_fn) { aot_executor_fn = Downcast(kv.second); diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 275bb1364456..1f325684c230 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -41,7 +41,8 @@ CodeGenSPIRV::CodeGenSPIRV(Target target) : spirv_support_(target) {} runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) { this->InitFuncState(); - ICHECK(f->attrs.HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model"; + ICHECK(f->attrs.HasNonzeroAttr(tir::attr::kNoAlias)) + << "SPIRV only takes restricted memory model"; std::vector pod_args; uint32_t i_buffer = 0; diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 98c9d5b4f5fa..9e9dde43aa4d 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -501,7 +501,8 @@ class StorageFlattener : public StmtExprMutator { PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_attributes) { // Only apply this pass to TIR from TE schedules - Optional from_legacy_te_schedule = func->attrs.GetAttr("from_legacy_te_schedule", Bool(false)); + Optional from_legacy_te_schedule = + func->attrs.GetAttr("from_legacy_te_schedule", Bool(false)); if (from_legacy_te_schedule.value()) { auto fptr = func.CopyOnWrite(); From 6fa3351a48d4b813b1e854ac214e9bc57587c710 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 13 Aug 2021 18:20:55 -0700 Subject: [PATCH 3/9] Fix documentation --- include/tvm/ir/attrs.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 3aaef82e89b9..317d8be5a16c 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -295,7 +295,7 @@ class DictAttrs : public Attrs { * * \endcode */ - bool HasNonzeroAttr(const std::string& attr_key) const { + bool zeroAttr(const std::string& attr_key) const { return GetAttr(attr_key, 0) != 0; } @@ -344,17 +344,17 @@ inline TAttrs AttrsWithDefaultValues() { * \endcode */ template -inline TFunc WithAttr(TFunc func, const std::string& attr_key, ObjectRef attr_value) { +inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_value) { using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); - TNode* node = func.CopyOnWrite(); + TNode* node = input.CopyOnWrite(); if (node->attrs.defined()) { node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value); } else { Map dict = {{attr_key, attr_value}}; node->attrs = DictAttrs(dict); } - return func; + return input; } // Namespace containing detail implementations From 493dda4cbdf89501157700a54d37bc77f0683e1e Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 13 Aug 2021 18:27:53 -0700 Subject: [PATCH 4/9] fix typo --- include/tvm/ir/attrs.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 317d8be5a16c..e0eac1506724 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -295,7 +295,7 @@ class DictAttrs : public Attrs { * * \endcode */ - bool zeroAttr(const std::string& attr_key) const { + bool HasNonzeroAttr(const std::string& attr_key) const { return GetAttr(attr_key, 0) != 0; } @@ -354,7 +354,7 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_v Map dict = {{attr_key, attr_value}}; node->attrs = DictAttrs(dict); } - return input; + return func; } // Namespace containing detail implementations From 25d6a85e7ee8a428620190b8123bbc7709735ec3 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Fri, 13 Aug 2021 18:34:30 -0700 Subject: [PATCH 5/9] Another typo! --- include/tvm/ir/attrs.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index e0eac1506724..74c98da00189 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -354,7 +354,7 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_v Map dict = {{attr_key, attr_value}}; node->attrs = DictAttrs(dict); } - return func; + return input; } // Namespace containing detail implementations From 13b72b0d65b66a94395311bc96275d54e322905e Mon Sep 17 00:00:00 2001 From: electriclilies Date: Mon, 16 Aug 2021 11:03:32 -0700 Subject: [PATCH 6/9] Revert GetAttrs to ->attrs.GetAttrs change --- include/tvm/ir/attrs.h | 108 ------------------ include/tvm/ir/function.h | 103 +++++++++++++++++ include/tvm/ir/module.h | 2 - include/tvm/target/target.h | 3 +- src/driver/driver_api.cc | 4 +- src/relay/analysis/context_analysis.cc | 6 +- src/relay/analysis/extract_fused_functions.cc | 2 +- src/relay/analysis/feature.cc | 2 +- src/relay/analysis/get_calibration_data.cc | 6 +- src/relay/backend/aot_executor_codegen.cc | 2 +- src/relay/backend/compile_engine.cc | 10 +- .../contrib/arm_compute_lib/codegen.cc | 2 +- src/relay/backend/contrib/bnns/codegen.cc | 4 +- .../contrib/codegen_json/codegen_json.h | 6 +- src/relay/backend/contrib/dnnl/codegen.cc | 2 +- src/relay/backend/contrib/ethosn/codegen.cc | 4 +- src/relay/backend/graph_executor_codegen.cc | 4 +- src/relay/backend/graph_plan_memory.cc | 2 +- src/relay/backend/interpreter.cc | 4 +- src/relay/backend/te_compiler.cc | 22 ++-- src/relay/backend/utils.h | 6 +- src/relay/backend/vm/compiler.cc | 8 +- src/relay/backend/vm/inline_primitives.cc | 6 +- src/relay/backend/vm/lambda_lift.cc | 10 +- src/relay/ir/transform.cc | 4 +- src/target/build_common.h | 6 +- src/target/llvm/codegen_cpu.cc | 2 +- src/target/llvm/codegen_hexagon.cc | 2 +- src/target/llvm/codegen_llvm.cc | 4 +- src/target/llvm/llvm_module.cc | 4 +- src/target/opt/build_cuda_on.cc | 2 +- src/target/source/codegen_aocl.cc | 2 +- src/target/source/codegen_c.cc | 4 +- src/target/source/codegen_c_host.cc | 5 +- src/target/source/codegen_metal.cc | 6 +- src/target/source/codegen_opencl.cc | 2 +- src/target/source/codegen_vhls.cc | 4 +- src/target/spirv/build_vulkan.cc | 4 +- src/target/spirv/codegen_spirv.cc | 3 +- src/target/stackvm/codegen_stackvm.cc | 4 +- src/tir/analysis/verify_memory.cc | 4 +- src/tir/transforms/ir_utils.cc | 2 +- src/tir/transforms/lower_custom_datatypes.cc | 2 +- src/tir/transforms/lower_intrin.cc | 3 +- src/tir/transforms/lower_thread_allreduce.cc | 2 +- src/tir/transforms/lower_warp_memory.cc | 2 +- src/tir/transforms/make_packed_api.cc | 6 +- src/tir/transforms/make_unpacked_api.cc | 6 +- src/tir/transforms/remap_thread_axis.cc | 2 +- src/tir/transforms/split_host_device.cc | 4 +- src/tir/transforms/storage_flatten.cc | 3 +- 51 files changed, 203 insertions(+), 219 deletions(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 74c98da00189..da7bc12619bd 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -214,7 +214,6 @@ class DictAttrsNode : public BaseAttrsNode { void VisitNonDefaultAttrs(AttrVisitor* v) final; void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final; Array ListFieldInfo() const final; - // type info static constexpr const char* _type_key = "DictAttrs"; TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode); @@ -233,72 +232,6 @@ class DictAttrs : public Attrs { */ TVM_DLL explicit DictAttrs(Map dict); - // Utils for accessing attributes - // This needs to be on DictAttrs, not DictAttrsNode because we return the default - // value if DictAttrsNode is not defined. - /*! - * \brief Get a function attribute. - * - * \param attr_key The attribute key. - * \param default_value The default value if the key does not exist, defaults to nullptr. - * - * \return The result - * - * \tparam TOBjectRef the expected object type. - * \throw Error if the key exists but the value does not match TObjectRef - * - * \code - * - * void GetAttrExample(const BaseFunc& f) { - * auto value = f->attrs.GetAttr("AttrKey", 0); - * } - * - * \endcode - */ - template - Optional GetAttr( - const std::string& attr_key, - Optional default_value = Optional(nullptr)) const { - static_assert(std::is_base_of::value, - "Can only call GetAttr with ObjectRef types."); - if (!defined()) return default_value; - const DictAttrsNode* node = this->as(); - - auto it = node->dict.find(attr_key); - if (it != node->dict.end()) { - return Downcast>((*it).second); - } else { - return default_value; - } - } - // variant that uses TObjectRef to enable implicit conversion to default value. - template - Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { - return GetAttr(attr_key, Optional(default_value)); - } - /*! - * \brief Check whether the function has an non-zero integer attr. - * - * This function can be used to check whether an optional - * attribute mark(e.g. inline) exists. - * - * \param attr_key The key to the attribute. - * \return The check result. - * - * \code - * - * void HasNonzeroAttrExample(const BaseFunc& f) { - * if (f->attrs.HasNonzeroAttr(attr::kInline)) { - * // inline the function. - * } - * } - * - * \endcode - */ - bool HasNonzeroAttr(const std::string& attr_key) const { - return GetAttr(attr_key, 0) != 0; - } - TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode); }; @@ -316,47 +249,6 @@ inline TAttrs AttrsWithDefaultValues() { return TAttrs(n); } -/*! - * \brief Copy the function or module, but overrides - * the attribute value key with the value. - * - * \param input The thing to annotate (BaseFunc or IRModule) - * \param attr_key The attribute key. - * \param attr_value The value attribute value. - * - * \tparam TFunc The corresponding function or module type. - * - * \returns The new function or module with updated attributes. - * - * \note This function performs copy on write optimization for func and module. - * If we move a uniquely referenced func or module into WithAttr, - * then no additional copy will be performed. - * - * This is also why we make it as a function instead of a member function - * and why we pass by value in the first argument. - * - * \code - * - * // Recommended way to trigger copy on write - * func = WithAttr(std::move(func), "key1", value1); - * func = WithAttr(std::move(func), "key2", value2); - * - * \endcode - */ -template -inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_value) { - using TNode = typename TFunc::ContainerType; - static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); - TNode* node = input.CopyOnWrite(); - if (node->attrs.defined()) { - node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value); - } else { - Map dict = {{attr_key, attr_value}}; - node->attrs = DictAttrs(dict); - } - return input; -} - // Namespace containing detail implementations namespace detail { using runtime::TVMArgValue; diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 020b3de77ab3..09c074cb71bd 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -79,6 +79,67 @@ class BaseFuncNode : public RelayExprNode { /*! \brief Additional attributes storing the meta-data */ DictAttrs attrs; + /*! + * \brief Get a function attribute. + * + * \param attr_key The attribute key. + * \param default_value The default value if the key does not exist, defaults to nullptr. + * + * \return The result + * + * \tparam TOBjectRef the expected object type. + * \throw Error if the key exists but the value does not match TObjectRef + * + * \code + * + * void GetAttrExample(const BaseFunc& f) { + * auto value = f->GetAttr("AttrKey", 0); + * } + * + * \endcode + */ + template + Optional GetAttr( + const std::string& attr_key, + Optional default_value = Optional(nullptr)) const { + static_assert(std::is_base_of::value, + "Can only call GetAttr with ObjectRef types."); + if (!attrs.defined()) return default_value; + auto it = attrs->dict.find(attr_key); + if (it != attrs->dict.end()) { + return Downcast>((*it).second); + } else { + return default_value; + } + } + // variant that uses TObjectRef to enable implicit conversion to default value. + template + Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, Optional(default_value)); + } + /*! + * \brief Check whether the function has an non-zero integer attr. + * + * This function can be used to check whether an optional + * attribute mark(e.g. inline) exists. + * + * \param attr_key The key to the attribute. + * \return The check result. + * + * \code + * + * void HasNonzeroAttrExample(const BaseFunc& f) { + * if (f->HasNonzeroAttr(attr::kInline)) { + * // inline the function. + * } + * } + * + * \endcode + */ + bool HasNonzeroAttr(const std::string& attr_key) const { + return GetAttr(attr_key, 0) != 0; + } + static constexpr const char* _type_key = "BaseFunc"; static constexpr const uint32_t _type_child_slots = 2; TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode); @@ -93,6 +154,48 @@ class BaseFunc : public RelayExpr { TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode); }; +/*! + * \brief Create a new function that copies func, but overrides + * the attribute value key with the value. + * + * \param func The input function. + * \param attr_key The attribute key. + * \param attr_value The value attribute value. + * + * \tparam TFunc The corresponding function type. + * + * \returns The new function with updated attributes. + * + * \note This function performs copy on write optimization for func. + * If we move a uniquely referenced func into WithAttr, + * then no additional copy will be performed. + * + * This is also why we make it as a function instead of a member function + * and why we pass by value in the first argument. + * + * \code + * + * // Recommended way to trigger copy on write + * func = WithAttr(std::move(func), "key1", value1); + * func = WithAttr(std::move(func), "key2", value2); + * + * \endcode + */ +template ::value>::type> +inline TFunc WithAttr(TFunc func, const std::string& attr_key, ObjectRef attr_value) { + using TNode = typename TFunc::ContainerType; + static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); + TNode* node = func.CopyOnWrite(); + if (node->attrs.defined()) { + node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value); + } else { + Map dict = {{attr_key, attr_value}}; + node->attrs = DictAttrs(dict); + } + return func; +} + /*! * \brief Generic attribute names that can be attached to any function. * diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index c88f438c7cd8..638f132e3179 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -58,8 +58,6 @@ class IRModuleNode : public Object { Map type_definitions; /*! \brief The source map for the module. */ parser::SourceMap source_map; - /* \brief Additional attributes storing meta-data about the module. */ - DictAttrs attrs; IRModuleNode() : source_map() {} diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 362737852dc9..9c1fe55749e4 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -54,7 +54,7 @@ class TargetNode : public Object { /*! \brief Keys for this target */ Array keys; /*! \brief Collection of attributes */ - Map attrs; // TODO(@electriclilies): Unify with DictAttrs on IRModule + Map attrs; /*! * \brief The raw string representation of the target * \return the full device string to pass to codegen::Build @@ -101,7 +101,6 @@ class TargetNode : public Object { * \param default_value The value returned if the key is not present * \return An optional, NullOpt if not found, otherwise the value found */ - // TODO(@electriclilies): Remove once we have removed the target attrs template Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { return GetAttr(attr_key, Optional(default_value)); diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 9c335d95b156..d6af9936ca40 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -401,7 +401,7 @@ std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target auto host_pass_list = { Filter([](const tir::PrimFunc& f) { - return f->attrs.GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != + return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != CallingConv::kDeviceKernelLaunch; }), BindTarget(target_host), @@ -418,7 +418,7 @@ std::pair SplitDevHostFuncs(IRModule mod_mixed, const Target // device pipeline auto device_pass_list = { Filter([](const tir::PrimFunc& f) { - return f->attrs.GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == + return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDeviceKernelLaunch; }), BindTarget(target), diff --git a/src/relay/analysis/context_analysis.cc b/src/relay/analysis/context_analysis.cc index 66bcfb098fce..35813f67d094 100644 --- a/src/relay/analysis/context_analysis.cc +++ b/src/relay/analysis/context_analysis.cc @@ -329,7 +329,7 @@ class ContextAnalyzer : public MixedModeVisitor { auto func = GetRef(fn); // No need to step into fused primitive functions as they are handled as // a whole. - if (fn->attrs.HasNonzeroAttr(attr::kPrimitive)) { + if (fn->HasNonzeroAttr(attr::kPrimitive)) { return; } @@ -432,9 +432,7 @@ class ContextAnalyzer : public MixedModeVisitor { } // Check if a function is a closure. - bool IsClosure(const Function& func) { - return func->attrs.GetAttr(attr::kClosure, 0) != 0; - } + bool IsClosure(const Function& func) { return func->GetAttr(attr::kClosure, 0) != 0; } // Check if a function is a currying function. bool IsCurrying(const Function& func) { diff --git a/src/relay/analysis/extract_fused_functions.cc b/src/relay/analysis/extract_fused_functions.cc index 9ea2ba9828f1..e76b54e2d0b7 100644 --- a/src/relay/analysis/extract_fused_functions.cc +++ b/src/relay/analysis/extract_fused_functions.cc @@ -53,7 +53,7 @@ class FusedFunctionExtractorWrapper : private ExprVisitor { Map functions; void VisitExpr_(const FunctionNode* n) final { - if (n->attrs.HasNonzeroAttr(attr::kPrimitive)) { + if (n->HasNonzeroAttr(attr::kPrimitive)) { // Add function to functions, keyed by function hash string Function func = Function(n->params, n->body, n->ret_type, n->type_params, n->attrs); size_t hash_ = tvm::StructuralHash()(func); diff --git a/src/relay/analysis/feature.cc b/src/relay/analysis/feature.cc index 72964865fd1f..f72b4e105749 100644 --- a/src/relay/analysis/feature.cc +++ b/src/relay/analysis/feature.cc @@ -60,7 +60,7 @@ FeatureSet DetectFeature(const Expr& expr) { DETECT_DEFAULT_CONSTRUCT(Tuple) DETECT_DEFAULT_CONSTRUCT(TupleGetItem) DETECT_CONSTRUCT(Function, { - if (!op->attrs.HasNonzeroAttr(attr::kPrimitive)) { + if (!op->HasNonzeroAttr(attr::kPrimitive)) { ExprVisitor::VisitExpr_(op); } }) diff --git a/src/relay/analysis/get_calibration_data.cc b/src/relay/analysis/get_calibration_data.cc index 80460f9c52c5..12bab1e38ddd 100644 --- a/src/relay/analysis/get_calibration_data.cc +++ b/src/relay/analysis/get_calibration_data.cc @@ -55,7 +55,7 @@ class Collector : public ExprRewriter { ICHECK(module_->ContainGlobalVar(var->name_hint)) << "Function " << var << " is not defined"; // we only handle functions with Compiler attribute set auto func = Downcast(module_->Lookup(var)); - if (func->attrs.GetAttr(attr::kCompiler)) { + if (func->GetAttr(attr::kCompiler)) { // collect all the inputs and outputs for (const auto& it : call->args) new_outputs_.push_back(it); new_outputs_.push_back(post); @@ -110,7 +110,7 @@ IRModule GetCalibrateModule(IRModule module) { for (const auto& pair : glob_funcs) { if (auto* fn = pair.second.as()) { auto func = GetRef(fn); - if (func->attrs.GetAttr(attr::kCompiler)) { + if (func->GetAttr(attr::kCompiler)) { // we need to inline the functions in order to run grpah runtime func = WithAttr(std::move(func), attr::kInline, tvm::Integer(1)); // reset the compiler attribute to null for llvm execution @@ -145,7 +145,7 @@ class OutputMapper : public ExprRewriter { << "Repeated function call " << var << " is not supported."; auto func = Downcast(module_->Lookup(var)); // we only handle functions with Compiler attribute set - if (func->attrs.GetAttr(attr::kCompiler)) { + if (func->GetAttr(attr::kCompiler)) { Array info; // the first value is the offset info.push_back(Integer(*offset_)); diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index a9b381c0461f..54a10add2f07 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -440,7 +440,7 @@ class AOTExecutorCodegen : public ExprVisitor { void VisitExpr_(const GlobalVarNode* op) override { throw std::runtime_error(""); } void VisitExpr_(const IfNode* op) override { throw std::invalid_argument("if not supported"); } void VisitExpr_(const FunctionNode* op) override { - ICHECK(op->attrs.GetAttr(attr::kCompiler).defined()) + ICHECK(op->GetAttr(attr::kCompiler).defined()) << "FunctionNode only supported by custom codegen"; } void VisitExpr_(const RefCreateNode* op) override { diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index beea25efd940..6142e8323dea 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -89,13 +89,13 @@ class CompileEngineImpl : public CompileEngineNode { auto src_func = it.first->source_func; ICHECK(src_func.defined()); - if (src_func->attrs.GetAttr(attr::kCompiler).defined()) { - auto code_gen = src_func->attrs.GetAttr(attr::kCompiler); + if (src_func->GetAttr(attr::kCompiler).defined()) { + auto code_gen = src_func->GetAttr(attr::kCompiler); ICHECK(code_gen.defined()) << "No external codegen is set"; std::string code_gen_name = code_gen.value(); cached_ext_funcs.push_back(it.first); - auto symbol_name = src_func->attrs.GetAttr(tvm::attr::kGlobalSymbol); + auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(symbol_name.defined()) << "No external symbol is set for:\n" << AsText(src_func, false) << "\n" << "Functions with external codegen must have the " @@ -186,9 +186,9 @@ class CompileEngineImpl : public CompileEngineNode { // No need to lower external functions for now. We will invoke the external // codegen tool once and lower all functions together. - if (key->source_func->attrs.GetAttr(attr::kCompiler).defined()) { + if (key->source_func->GetAttr(attr::kCompiler).defined()) { auto ir_module = IRModule(); - const auto name_node = key->source_func->attrs.GetAttr(tvm::attr::kGlobalSymbol); + const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(name_node.defined()) << "External function has not been attached a name yet."; auto func_name = std::string(name_node.value()); auto target = Target("ext_dev"); diff --git a/src/relay/backend/contrib/arm_compute_lib/codegen.cc b/src/relay/backend/contrib/arm_compute_lib/codegen.cc index 0270b01ab8ff..8098c8d51274 100644 --- a/src/relay/backend/contrib/arm_compute_lib/codegen.cc +++ b/src/relay/backend/contrib/arm_compute_lib/codegen.cc @@ -87,7 +87,7 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer { << cn->op->GetTypeKey(); } auto fn = cn->op.as(); - auto comp = fn->attrs.GetAttr(attr::kComposite); + auto comp = fn->GetAttr(attr::kComposite); ICHECK(comp.defined()) << "Arm Compute Library JSON runtime only supports composite functions."; const std::string name = comp.value(); std::shared_ptr json_node; diff --git a/src/relay/backend/contrib/bnns/codegen.cc b/src/relay/backend/contrib/bnns/codegen.cc index 6464d015ebc1..72c32fb5b19e 100644 --- a/src/relay/backend/contrib/bnns/codegen.cc +++ b/src/relay/backend/contrib/bnns/codegen.cc @@ -80,7 +80,7 @@ class BNNSJSONSerializer : public backend::contrib::JSONSerializer { if (const auto* op_node = cn->op.as()) { name = op_node->name; } else if (const auto* fn = cn->op.as()) { - auto comp = fn->attrs.GetAttr(attr::kComposite); + auto comp = fn->GetAttr(attr::kComposite); ICHECK(comp.defined()) << "BNNS JSON runtime only supports composite functions."; name = comp.value(); @@ -176,7 +176,7 @@ struct BNNSConstantUpdater : public ConstantUpdater { private: bool isBNNSSpecificCompositeFunc(const FunctionNode* op) { - auto comp = op->attrs.GetAttr(attr::kComposite); + auto comp = op->GetAttr(attr::kComposite); if (!comp) return false; auto comp_name = comp.value(); diff --git a/src/relay/backend/contrib/codegen_json/codegen_json.h b/src/relay/backend/contrib/codegen_json/codegen_json.h index 0366f8d2b838..4966f3f01c7d 100644 --- a/src/relay/backend/contrib/codegen_json/codegen_json.h +++ b/src/relay/backend/contrib/codegen_json/codegen_json.h @@ -225,7 +225,7 @@ class JSONSerializer : public MemoizedExprTranslatorattrs.get(); extractor.Extract(const_cast(call_attr)); } else if (const auto* fn = cn->op.as()) { - auto pattern = fn->attrs.GetAttr(attr::kPartitionedFromPattern); + auto pattern = fn->GetAttr(attr::kPartitionedFromPattern); ICHECK(pattern.defined()); std::vector values; values.push_back(pattern.value()); @@ -267,7 +267,7 @@ class JSONSerializer : public MemoizedExprTranslatorop.as()) { name = op_node->name; } else if (const auto* fn = cn->op.as()) { - auto comp = fn->attrs.GetAttr(attr::kComposite); + auto comp = fn->GetAttr(attr::kComposite); ICHECK(comp.defined()) << "JSON runtime only supports composite functions."; name = comp.value(); } else { @@ -298,7 +298,7 @@ class JSONSerializer : public MemoizedExprTranslator VisitExpr_(const FunctionNode* fn) { - ICHECK(fn->attrs.GetAttr(attr::kComposite).defined()) + ICHECK(fn->GetAttr(attr::kComposite).defined()) << "JSON runtime only supports composite functions"; // FunctionNode should be handled by the caller. return {}; diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index cfde6550a431..e96255e976e9 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -440,7 +440,7 @@ class DNNLJSONSerializer : public backend::contrib::JSONSerializer { if (const auto* op_node = cn->op.as()) { name = op_node->name; } else if (const auto* fn = cn->op.as()) { - auto comp = fn->attrs.GetAttr(attr::kComposite); + auto comp = fn->GetAttr(attr::kComposite); ICHECK(comp.defined()) << "DNNL JSON runtime only supports composite functions."; name = comp.value(); diff --git a/src/relay/backend/contrib/ethosn/codegen.cc b/src/relay/backend/contrib/ethosn/codegen.cc index f3343c5e2648..97b308e51e18 100644 --- a/src/relay/backend/contrib/ethosn/codegen.cc +++ b/src/relay/backend/contrib/ethosn/codegen.cc @@ -53,7 +53,7 @@ bool IsEthosnFunc(const Call& call, const std::string& op_name) { if (call->op->IsInstance()) { Function func = Downcast(call->op); ICHECK(func.defined()); - auto name_node = func->attrs.GetAttr(attr::kComposite); + auto name_node = func->GetAttr(attr::kComposite); return name_node.value() == op_name; } return false; @@ -521,7 +521,7 @@ runtime::Module EthosnCompiler::CreateRuntimeModule(const ObjectRef& ref) { if (ref->IsInstance()) { IRModule mod; Function func = Downcast(ref); - auto name_node = func->attrs.GetAttr(tvm::attr::kGlobalSymbol); + auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(name_node.defined()) << "Failed to retrieved external symbol."; GlobalVar gvar = GlobalVar(name_node.value()); mod->Add(gvar, func); diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index bbd80eccce39..cc54a52be200 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -226,7 +226,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorattrs.GetAttr(attr::kCompiler).defined()) { + if (func->GetAttr(attr::kCompiler).defined()) { UpdateConstants(func, ¶ms_); } @@ -473,7 +473,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator VisitExpr_(const FunctionNode* op) override { - ICHECK(op->attrs.GetAttr(attr::kCompiler).defined()) + ICHECK(op->GetAttr(attr::kCompiler).defined()) << "Only functions supported by custom codegen"; return {}; } diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index a4fad4bfe4f5..93c823d8a007 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -345,7 +345,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { */ static bool IsReshape(const CallNode* call) { if (const auto* fn = call->op.as()) { - return fn->attrs.HasNonzeroAttr(attr::kReshapeOnly); + return fn->HasNonzeroAttr(attr::kReshapeOnly); } if (call->attrs.defined()) { diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 3970d6c041d9..6ebb17e93eca 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -480,7 +480,7 @@ class Interpreter : public ExprFunctor, bool is_dyn = IsDynamic(ret_type); if (is_dyn) { - ICHECK(func->attrs.HasNonzeroAttr(attr::kPrimitive)); + ICHECK(func->HasNonzeroAttr(attr::kPrimitive)); out_shapes = ComputeDynamicShape(func, args); } @@ -519,7 +519,7 @@ class Interpreter : public ExprFunctor, ObjectRef Invoke(const InterpreterClosure& closure, const tvm::Array& args, const Var& bind = Var()) { // Get a reference to the function inside the closure. - if (closure->func->attrs.HasNonzeroAttr(attr::kPrimitive)) { + if (closure->func->HasNonzeroAttr(attr::kPrimitive)) { return InvokePrimitiveOp(closure->func, args); } auto func = closure->func; diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 07fbc6d68265..7840960ec268 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -111,12 +111,12 @@ class TECompilerImpl : public TECompilerNode { for (const auto& it : cache_) { auto src_func = it.first->source_func; ICHECK(src_func.defined()); - if (src_func->attrs.GetAttr(attr::kCompiler).defined()) { - auto code_gen = src_func->attrs.GetAttr(attr::kCompiler); + if (src_func->GetAttr(attr::kCompiler).defined()) { + auto code_gen = src_func->GetAttr(attr::kCompiler); std::string code_gen_name = code_gen.value(); cached_ext_funcs.push_back(it.first); - auto symbol_name = src_func->attrs.GetAttr(tvm::attr::kGlobalSymbol); + auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(symbol_name.defined()) << "No external symbol is set for:\n" << AsText(src_func, false); @@ -187,9 +187,9 @@ class TECompilerImpl : public TECompilerNode { // No need to lower external functions for now. We will invoke the external // codegen tool once and lower all functions together. - if (key->source_func->attrs.GetAttr(attr::kCompiler).defined()) { + if (key->source_func->GetAttr(attr::kCompiler).defined()) { auto ir_module = IRModule(); - const auto name_node = key->source_func->attrs.GetAttr(tvm::attr::kGlobalSymbol); + const auto name_node = key->source_func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(name_node.defined()) << "External function has not been attached a name yet."; auto func_name = GetUniqueName(name_node.value(), &name_map_); auto target = Target("ext_dev"); @@ -325,7 +325,7 @@ class LowerTensorExpr : public ExprMutator { return ExprMutator::VisitExpr_(call); } - if (!func->attrs.HasNonzeroAttr(attr::kPrimitive)) { + if (!func->HasNonzeroAttr(attr::kPrimitive)) { // Provide a callback hook which allows one-level up code generators to // act when we process a function. this->process_fn(func); @@ -340,7 +340,7 @@ class LowerTensorExpr : public ExprMutator { Target target; - if (func->attrs.GetAttr(attr::kCompiler).defined()) { + if (func->GetAttr(attr::kCompiler).defined()) { target = Target("ext_dev"); CCacheKey key = CCacheKey(func, target); CachedFunc ext_func = compiler_->Lower(key, module_name_); @@ -390,7 +390,7 @@ class LowerTensorExpr : public ExprMutator { this->process_fn(func_with_metadata); auto tir_call_attrs = make_object(); - if (func->attrs.HasNonzeroAttr(attr::kReshapeOnly)) { + if (func->HasNonzeroAttr(attr::kReshapeOnly)) { tir_call_attrs->metadata.Set(attr::kReshapeOnly, tvm::Integer(1)); } @@ -604,13 +604,13 @@ void UpdateFunctionMetadata(Function relay_func, Map relay_primfuncs; Optional> prim_fns = - relay_func->attrs.GetAttr>("prim_funcs"); + relay_func->GetAttr>("prim_funcs"); CHECK(prim_fns) << "primitive functions not set on Relay function by TECompiler."; - Optional prim_fn_var = relay_func->attrs.GetAttr("prim_fn_var"); + Optional prim_fn_var = relay_func->GetAttr("prim_fn_var"); CHECK(prim_fn_var) << "prim_fn_var must be set on Relay functions by TECompiler."; - Optional relay_target = relay_func->attrs.GetAttr("target"); + Optional relay_target = relay_func->GetAttr("target"); CHECK(relay_target) << "target must be set on Relay functions by the TECompiler."; for (const auto& kv : prim_fns.value()) { diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 615d212e0169..a0c7a5aad26d 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -173,10 +173,10 @@ struct ConstantUpdater : public ExprVisitor { */ inline void UpdateConstants(Function func, std::unordered_map* params) { - auto codegen = func->attrs.GetAttr(attr::kCompiler); + auto codegen = func->GetAttr(attr::kCompiler); ICHECK(codegen.defined()) << "No external codegen is set"; std::string codegen_name = codegen.value(); - const auto name_node = func->attrs.GetAttr(tvm::attr::kGlobalSymbol); + const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); std::string symbol = std::string(name_node.value()); std::string const_update_name = "relay.ext." + codegen_name + ".constant_updater"; // Get the constant updater for the external codegen @@ -392,7 +392,7 @@ inline const CallNode* GetRootCall(const CallNode* current_call, int depth, * \return An external symbol. */ inline std::string GetExtSymbol(const Function& func) { - const auto name_node = func->attrs.GetAttr(tvm::attr::kGlobalSymbol); + const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(name_node.defined()) << "Fail to retrieve external symbol."; return std::string(name_node.value()); } diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 62c45ff0168f..b3eab91d202c 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -501,7 +501,7 @@ class VMFunctionCompiler : ExprFunctor { void EmitInvokeTVMOp(const Function& func, const Expr& inputs, const Expr& outputs) { std::vector argument_registers; - ICHECK(func->attrs.GetAttr(attr::kPrimitive, 0) != 0) + ICHECK(func->GetAttr(attr::kPrimitive, 0) != 0) << "internal error: invoke_tvm_op requires the first argument to be a relay::Function"; auto input_tuple = inputs.as(); @@ -526,7 +526,7 @@ class VMFunctionCompiler : ExprFunctor { Target target; - if (func->attrs.GetAttr(attr::kCompiler).defined()) { + if (func->GetAttr(attr::kCompiler).defined()) { target = Target("ext_dev"); } else { // Next generate the invoke instruction. @@ -553,7 +553,7 @@ class VMFunctionCompiler : ExprFunctor { auto cfunc = context_->compiler->Lower(key, mangle_fn); auto op_index = -1; - if (func->attrs.GetAttr(attr::kCompiler).defined()) { + if (func->GetAttr(attr::kCompiler).defined()) { op_index = context_->cached_funcs.size(); context_->cached_funcs.push_back(cfunc); } else { @@ -765,7 +765,7 @@ class VMFunctionCompiler : ExprFunctor { } void VisitExpr_(const FunctionNode* func_node) { - if (!func_node->attrs.HasNonzeroAttr(attr::kPrimitive)) { + if (!func_node->HasNonzeroAttr(attr::kPrimitive)) { LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl << "Program: " << AsText(GetRef(func_node), false) << std::endl << "AST: " << GetRef(func_node); diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index 48976384c50a..05fb2a120620 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -97,7 +97,7 @@ struct PrimitiveInliner : ExprMutator { } if (auto func = op.as()) { - if (func->attrs.HasNonzeroAttr(attr::kPrimitive)) { + if (func->HasNonzeroAttr(attr::kPrimitive)) { tvm::Array call_args; for (auto arg : call->args) { auto new_arg = VisitExpr(arg); @@ -120,7 +120,7 @@ struct PrimitiveInliner : ExprMutator { } Expr VisitExpr_(const FunctionNode* func) { - if (func->attrs.HasNonzeroAttr(attr::kPrimitive)) { + if (func->HasNonzeroAttr(attr::kPrimitive)) { return GetRef(func); } else { return ExprMutator::VisitExpr_(func); @@ -133,7 +133,7 @@ struct PrimitiveInliner : ExprMutator { auto global = pair.first; auto base_func = pair.second; if (auto* n = base_func.as()) { - if (n->attrs.GetAttr(attr::kCompiler).defined()) continue; + if (n->GetAttr(attr::kCompiler).defined()) continue; auto func = GetRef(n); DLOG(INFO) << "Before inlining primitives: " << global << std::endl << AsText(func, false); diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index e5882159e1f0..c768a2c300ec 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -44,9 +44,7 @@ inline std::string GenerateName(const Function& func) { return std::string("lifted_name") + std::to_string(hash); } -bool IsClosure(const Function& func) { - return func->attrs.GetAttr(attr::kClosure, 0) != 0; -} +bool IsClosure(const Function& func) { return func->GetAttr(attr::kClosure, 0) != 0; } Function MarkClosure(Function func) { return WithAttr(std::move(func), attr::kClosure, tvm::Integer(1)); @@ -66,7 +64,7 @@ class LambdaLifter : public ExprMutator { auto pre_visit = [this](const LetNode* op) { bool is_lambda = false; if (auto func = op->value.as()) { - if (!func->attrs.HasNonzeroAttr(attr::kPrimitive)) { + if (!func->HasNonzeroAttr(attr::kPrimitive)) { is_lambda = true; this->letrec_.push_back(op->var); } @@ -106,7 +104,7 @@ class LambdaLifter : public ExprMutator { auto func = GetRef(func_node); // We should not transform primitive functions. - if (func->attrs.HasNonzeroAttr(attr::kPrimitive)) { + if (func->HasNonzeroAttr(attr::kPrimitive)) { return std::move(func); } @@ -226,7 +224,7 @@ class LambdaLifter : public ExprMutator { auto glob_funcs = module_->functions; for (auto pair : glob_funcs) { if (auto* n = pair.second.as()) { - if (n->attrs.GetAttr(attr::kCompiler).defined()) continue; + if (n->GetAttr(attr::kCompiler).defined()) continue; auto func = GetRef(n); func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, func->attrs); diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index a7f0959e169d..4a7974cae5ae 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -163,8 +163,8 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx) } bool FunctionPassNode::SkipFunction(const Function& func) const { - return (func->attrs.GetAttr(attr::kCompiler).defined()) || - func->attrs.GetAttr(attr::kSkipOptimization, 0) != 0; + return (func->GetAttr(attr::kCompiler).defined()) || + func->GetAttr(attr::kSkipOptimization, 0) != 0; } Pass CreateFunctionPass( diff --git a/src/target/build_common.h b/src/target/build_common.h index 313ab5fa72bb..c66c2b52822e 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -50,18 +50,18 @@ inline std::unordered_map ExtractFuncInfo(co for (size_t i = 0; i < f->params.size(); ++i) { info.arg_types.push_back(f->params[i].dtype()); } - if (auto opt = f->attrs.GetAttr>(tir::attr::kDeviceThreadAxis)) { + if (auto opt = f->GetAttr>(tir::attr::kDeviceThreadAxis)) { auto thread_axis = opt.value(); for (size_t i = 0; i < thread_axis.size(); ++i) { info.launch_param_tags.push_back(thread_axis[i]->thread_tag); } } - if (auto opt = f->attrs.GetAttr(tir::attr::kDeviceUseDynSharedMemory)) { + if (auto opt = f->GetAttr(tir::attr::kDeviceUseDynSharedMemory)) { if (opt.value()) { info.launch_param_tags.push_back(runtime::kUseDynamicSharedMemoryTag); } } - auto global_symbol = f->attrs.GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); fmap[static_cast(global_symbol.value())] = info; } return fmap; diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 3e7c622a2f67..ab96d6e69d14 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -118,7 +118,7 @@ void CodeGenCPU::Init(const std::string& module_name, llvm::TargetMachine* tm, void CodeGenCPU::AddFunction(const PrimFunc& f) { CodeGenLLVM::AddFunction(f); if (f_tvm_register_system_symbol_ != nullptr) { - auto global_symbol = f->attrs.GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; export_system_symbols_.emplace_back( diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index f5ab55421ce9..9d324d56887f 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -47,7 +47,7 @@ namespace tvm { namespace codegen { static std::string get_name(const PrimFunc& f) { - auto global_symbol = f->attrs.GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; return std::string(global_symbol.value()); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index f87c5e65fdbd..6aabdc1bd804 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -116,7 +116,7 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { << "Cannot codegen function with buffer_map, please lower them first"; std::vector param_types; - is_restricted_ = f->attrs.HasNonzeroAttr(tir::attr::kNoAlias); + is_restricted_ = f->HasNonzeroAttr(tir::attr::kNoAlias); for (Var param : f->params) { param_types.push_back(GetLLVMType(param)); if (!is_restricted_ && param.dtype().is_handle()) { @@ -129,7 +129,7 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { llvm::FunctionType* ftype = llvm::FunctionType::get(ret_void ? t_void_ : t_int_, param_types, false); - auto global_symbol = f->attrs.GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; ICHECK(module_->getFunction(static_cast(global_symbol.value())) == nullptr) diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index defccc7cd4ac..12c7a3132947 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -234,10 +234,10 @@ class LLVMModuleNode final : public runtime::ModuleNode { continue; } auto f = Downcast(kv.second); - auto global_symbol = f->attrs.GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()); function_names_.push_back(global_symbol.value()); - if (f->attrs.HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { entry_func = global_symbol.value(); } funcs.push_back(f); diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index 63fe53b52092..4a2917daa5ed 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -134,7 +134,7 @@ runtime::Module BuildCUDA(IRModule mod, Target target) { for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenCUDA: Can only take PrimFunc"; auto f = Downcast(kv.second); - auto calling_conv = f->attrs.GetAttr(tvm::attr::kCallingConv); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); diff --git a/src/target/source/codegen_aocl.cc b/src/target/source/codegen_aocl.cc index e2125ccc9a52..17e38e9af6e6 100644 --- a/src/target/source/codegen_aocl.cc +++ b/src/target/source/codegen_aocl.cc @@ -43,7 +43,7 @@ runtime::Module BuildAOCL(IRModule mod, Target target, bool emulation) { for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodegenOpenCL: Can only take PrimFunc"; auto f = Downcast(kv.second); - auto calling_conv = f->attrs.GetAttr(tvm::attr::kCallingConv); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodegenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 4691f439b864..f676f0f598d8 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -77,10 +77,10 @@ void CodeGenC::AddFunction(const PrimFunc& f) { // reserve keywords ReserveKeywordsAsUnique(); - auto global_symbol = f->attrs.GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; - bool no_alias = f->attrs.HasNonzeroAttr(tir::attr::kNoAlias); + bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); this->PrintFuncPrefix(); this->PrintExtraAttrs(f); diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index eaed10c99a9c..dc849b8fa6b3 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -53,7 +53,7 @@ void CodeGenCHost::Init(bool output_ssa, bool emit_asserts, std::string target_s void CodeGenCHost::DefineModuleName() { decl_stream << "void* " << module_name_ << " = NULL;\n"; } void CodeGenCHost::AddFunction(const PrimFunc& f) { - auto global_symbol = f->attrs.GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute"; function_names_.push_back(global_symbol.value()); @@ -390,8 +390,7 @@ runtime::Module BuildCHost(IRModule mod, Target target) { // Make sure that the executor function is the last one to be code generated so that all the // symbols are available to tvm_run_func auto fun_name = std::string(kv.first->name_hint); - bool is_aot_executor_fn = - kv.second->attrs.GetAttr("runner_function", Bool(false)).value(); + bool is_aot_executor_fn = kv.second->GetAttr("runner_function", Bool(false)).value(); if (is_aot_executor_fn) { aot_executor_fn = Downcast(kv.second); diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 25e7a8dab72a..b44afec57d5d 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -58,7 +58,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { GetUniqueName("_"); // add to alloc buffer type. - auto global_symbol = f->attrs.GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; @@ -130,7 +130,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { ICHECK_EQ(GetUniqueName("threadIdx"), "threadIdx"); ICHECK_EQ(GetUniqueName("blockIdx"), "blockIdx"); int work_dim = 0; - auto thread_axis = f->attrs.GetAttr>(tir::attr::kDeviceThreadAxis).value(); + auto thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis).value(); for (IterVar iv : thread_axis) { runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag); @@ -340,7 +340,7 @@ runtime::Module BuildMetal(IRModule mod, Target target) { CodeGenMetal cg(target); cg.Init(output_ssa); auto f = Downcast(kv.second); - auto calling_conv = f->attrs.GetAttr(tvm::attr::kCallingConv); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index eb80d32017fc..edb614d9c122 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -292,7 +292,7 @@ runtime::Module BuildOpenCL(IRModule mod, Target target) { CodeGenOpenCL cg; cg.Init(output_ssa); auto f = Downcast(kv.second); - auto calling_conv = f->attrs.GetAttr(tvm::attr::kCallingConv); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc index 8c2b6b25ed0b..9896d8b833f9 100644 --- a/src/target/source/codegen_vhls.cc +++ b/src/target/source/codegen_vhls.cc @@ -148,7 +148,7 @@ runtime::Module BuildSDAccel(IRModule mod, Target target) { for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenVHLS: Can only take PrimFunc"; auto f = Downcast(kv.second); - auto calling_conv = f->attrs.GetAttr(tvm::attr::kCallingConv); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenVLHS: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); @@ -170,7 +170,7 @@ runtime::Module BuildSDAccel(IRModule mod, Target target) { code = (*f)(code).operator std::string(); } - auto global_symbol = f->attrs.GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; kernel_info.push_back({global_symbol.value(), code}); diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index 561103305bc4..e922942e8acf 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -113,10 +113,10 @@ runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction) for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenSPIRV: Can only take PrimFunc"; auto f = Downcast(kv.second); - auto calling_conv = f->attrs.GetAttr(tvm::attr::kCallingConv); + auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - auto global_symbol = f->attrs.GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 1f325684c230..42d0027a326f 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -41,8 +41,7 @@ CodeGenSPIRV::CodeGenSPIRV(Target target) : spirv_support_(target) {} runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) { this->InitFuncState(); - ICHECK(f->attrs.HasNonzeroAttr(tir::attr::kNoAlias)) - << "SPIRV only takes restricted memory model"; + ICHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model"; std::vector pod_args; uint32_t i_buffer = 0; diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index c40e15bee92f..402e3291975f 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -516,7 +516,7 @@ runtime::Module BuildStackVM(IRModule mod, Target target) { for (auto kv : mod->functions) { ICHECK(kv.second->IsInstance()) << "CodeGenStackVM: Can only take PrimFunc"; auto f = Downcast(kv.second); - auto global_symbol = f->attrs.GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute"; std::string f_name = global_symbol.value(); @@ -524,7 +524,7 @@ runtime::Module BuildStackVM(IRModule mod, Target target) { ICHECK(!fmap.count(f_name)) << "Function name " << f_name << "already exist in list"; fmap[f_name] = std::move(vm); - if (f->attrs.HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { entry_func = f_name; } } diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 14e5e7f9973a..2089ead98168 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -169,10 +169,10 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { /// Interface of VerifyMemory pass std::vector VerifyMemory_(const PrimFunc& func) { - auto target = func->attrs.GetAttr(tvm::attr::kTarget); + auto target = func->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "VerifyMemory: Require the target attribute"; - if (func->attrs.GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == + if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDefault) { MemoryAccessVerifier v(func, target.value()->kind->device_type); v.Run(); diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 5603247f6646..a41905c148bf 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -245,7 +245,7 @@ Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region } Bool IsFromLegacyTESchedule(PrimFunc f) { - Optional from_legacy_te_schedule = f->attrs.GetAttr("from_legacy_te_schedule", Bool(false)); + Optional from_legacy_te_schedule = f->GetAttr("from_legacy_te_schedule", Bool(false)); return from_legacy_te_schedule.value(); } diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index 12b231c3bf2c..21f1b18d523b 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -207,7 +207,7 @@ namespace transform { Pass LowerCustomDatatypes() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - auto target = f->attrs.GetAttr(tvm::attr::kTarget); + auto target = f->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "LowerCustomDatatypes: Require the target attribute"; n->body = CustomDatatypesLowerer(target.value()->kind->name)(std::move(n->body)); diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 6414fc1819b8..2555002d29b0 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -297,10 +297,9 @@ namespace transform { Pass LowerIntrin() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - auto target = f->attrs.GetAttr(tvm::attr::kTarget); + auto target = f->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "LowerIntrin: Require the target attribute"; arith::Analyzer analyzer; - // TODO(@electriclilies): This is most likely a problem.. auto mtriple = target.value()->GetAttr("mtriple", ""); n->body = IntrinInjecter(&analyzer, target.value()->kind->name, mtriple.value())(std::move(n->body)); diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 572ce88ba012..481b1bfd4b19 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -600,7 +600,7 @@ namespace transform { Pass LowerThreadAllreduce() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - auto target = f->attrs.GetAttr(tvm::attr::kTarget); + auto target = f->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute"; const TargetNode* target_node = target.as(); ThreadAllreduceBuilder thread_all_reduce(target_node); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 37ded00bfc0f..30ec148c37dd 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -387,7 +387,7 @@ namespace transform { Pass LowerWarpMemory() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); - auto target = f->attrs.GetAttr(tvm::attr::kTarget); + auto target = f->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; int warp_size = target.value()->GetAttr("thread_warp_size", 1).value(); WarpMemoryRewriter warp_memory_rewriter(warp_size); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 2f0c389c0373..393ce6c286b4 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -106,10 +106,10 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { } PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { - auto global_symbol = func->attrs.GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute"; - auto target = func->attrs.GetAttr(tvm::attr::kTarget); + auto target = func->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "MakePackedAPI: Require the target attribute"; int target_device_type = target.value()->kind->device_type; @@ -294,7 +294,7 @@ Pass MakePackedAPI(int num_unpacked_args) { for (const auto& kv : mptr->functions) { if (auto* n = kv.second.as()) { PrimFunc func = GetRef(n); - if (func->attrs.GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == + if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDefault) { auto updated_func = MakePackedAPI(std::move(func), num_unpacked_args); updates.push_back({kv.first, updated_func}); diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc index d08f3320accb..6e8793fbd367 100644 --- a/src/tir/transforms/make_unpacked_api.cc +++ b/src/tir/transforms/make_unpacked_api.cc @@ -41,10 +41,10 @@ namespace tvm { namespace tir { PrimFunc MakeUnpackedAPI(PrimFunc&& func) { - auto global_symbol = func->attrs.GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol) << "MakeUnpackedAPI: Expect PrimFunc to have the global_symbol attribute"; - auto target = func->attrs.GetAttr(tvm::attr::kTarget); + auto target = func->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "MakeUnpackedAPI: Require the target attribute"; auto* func_ptr = func.CopyOnWrite(); @@ -111,7 +111,7 @@ Pass MakeUnpackedAPI() { for (const auto& kv : mptr->functions) { if (auto* n = kv.second.as()) { PrimFunc func = GetRef(n); - if (func->attrs.GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == + if (func->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == CallingConv::kDefault) { auto updated_func = MakeUnpackedAPI(std::move(func)); updates.push_back({kv.first, updated_func}); diff --git a/src/tir/transforms/remap_thread_axis.cc b/src/tir/transforms/remap_thread_axis.cc index 61867d18923e..e101e6b904ce 100644 --- a/src/tir/transforms/remap_thread_axis.cc +++ b/src/tir/transforms/remap_thread_axis.cc @@ -75,7 +75,7 @@ PrimFunc RemapThreadAxis(PrimFunc&& f, Map thread_map) tmap[kv.first] = kv.second; } - auto opt_thread_axis = f->attrs.GetAttr>(tir::attr::kDeviceThreadAxis); + auto opt_thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis); ICHECK(opt_thread_axis != nullptr) << "Require attribute " << tir::attr::kDeviceThreadAxis; auto thread_axis = opt_thread_axis.value(); auto* n = f.CopyOnWrite(); diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 05be4c06ffaa..795ae9d6a73a 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -311,9 +311,9 @@ class HostDeviceSplitter : public StmtMutator { }; PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) { - auto target = func->attrs.GetAttr(tvm::attr::kTarget); + auto target = func->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "SplitHostDevice: Require the target attribute"; - auto global_symbol = func->attrs.GetAttr(tvm::attr::kGlobalSymbol); + auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute"; diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 9e9dde43aa4d..2c32cc7f0883 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -501,8 +501,7 @@ class StorageFlattener : public StmtExprMutator { PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_attributes) { // Only apply this pass to TIR from TE schedules - Optional from_legacy_te_schedule = - func->attrs.GetAttr("from_legacy_te_schedule", Bool(false)); + Optional from_legacy_te_schedule = func->GetAttr("from_legacy_te_schedule", Bool(false)); if (from_legacy_te_schedule.value()) { auto fptr = func.CopyOnWrite(); From dfe8f07965d992d0ec8058348a585bcd31ff5d58 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Mon, 16 Aug 2021 11:25:03 -0700 Subject: [PATCH 7/9] Didn't mean to revert these --- include/tvm/ir/attrs.h | 108 ++++++++++++++++++++++++++++++++++++++ include/tvm/ir/function.h | 103 ------------------------------------ include/tvm/ir/module.h | 2 + 3 files changed, 110 insertions(+), 103 deletions(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index da7bc12619bd..74c98da00189 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -214,6 +214,7 @@ class DictAttrsNode : public BaseAttrsNode { void VisitNonDefaultAttrs(AttrVisitor* v) final; void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final; Array ListFieldInfo() const final; + // type info static constexpr const char* _type_key = "DictAttrs"; TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode); @@ -232,6 +233,72 @@ class DictAttrs : public Attrs { */ TVM_DLL explicit DictAttrs(Map dict); + // Utils for accessing attributes + // This needs to be on DictAttrs, not DictAttrsNode because we return the default + // value if DictAttrsNode is not defined. + /*! + * \brief Get a function attribute. + * + * \param attr_key The attribute key. + * \param default_value The default value if the key does not exist, defaults to nullptr. + * + * \return The result + * + * \tparam TOBjectRef the expected object type. + * \throw Error if the key exists but the value does not match TObjectRef + * + * \code + * + * void GetAttrExample(const BaseFunc& f) { + * auto value = f->attrs.GetAttr("AttrKey", 0); + * } + * + * \endcode + */ + template + Optional GetAttr( + const std::string& attr_key, + Optional default_value = Optional(nullptr)) const { + static_assert(std::is_base_of::value, + "Can only call GetAttr with ObjectRef types."); + if (!defined()) return default_value; + const DictAttrsNode* node = this->as(); + + auto it = node->dict.find(attr_key); + if (it != node->dict.end()) { + return Downcast>((*it).second); + } else { + return default_value; + } + } + // variant that uses TObjectRef to enable implicit conversion to default value. + template + Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, Optional(default_value)); + } + /*! + * \brief Check whether the function has an non-zero integer attr. + * + * This function can be used to check whether an optional + * attribute mark(e.g. inline) exists. + * + * \param attr_key The key to the attribute. + * \return The check result. + * + * \code + * + * void HasNonzeroAttrExample(const BaseFunc& f) { + * if (f->attrs.HasNonzeroAttr(attr::kInline)) { + * // inline the function. + * } + * } + * + * \endcode + */ + bool HasNonzeroAttr(const std::string& attr_key) const { + return GetAttr(attr_key, 0) != 0; + } + TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode); }; @@ -249,6 +316,47 @@ inline TAttrs AttrsWithDefaultValues() { return TAttrs(n); } +/*! + * \brief Copy the function or module, but overrides + * the attribute value key with the value. + * + * \param input The thing to annotate (BaseFunc or IRModule) + * \param attr_key The attribute key. + * \param attr_value The value attribute value. + * + * \tparam TFunc The corresponding function or module type. + * + * \returns The new function or module with updated attributes. + * + * \note This function performs copy on write optimization for func and module. + * If we move a uniquely referenced func or module into WithAttr, + * then no additional copy will be performed. + * + * This is also why we make it as a function instead of a member function + * and why we pass by value in the first argument. + * + * \code + * + * // Recommended way to trigger copy on write + * func = WithAttr(std::move(func), "key1", value1); + * func = WithAttr(std::move(func), "key2", value2); + * + * \endcode + */ +template +inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_value) { + using TNode = typename TFunc::ContainerType; + static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); + TNode* node = input.CopyOnWrite(); + if (node->attrs.defined()) { + node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value); + } else { + Map dict = {{attr_key, attr_value}}; + node->attrs = DictAttrs(dict); + } + return input; +} + // Namespace containing detail implementations namespace detail { using runtime::TVMArgValue; diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 09c074cb71bd..020b3de77ab3 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -79,67 +79,6 @@ class BaseFuncNode : public RelayExprNode { /*! \brief Additional attributes storing the meta-data */ DictAttrs attrs; - /*! - * \brief Get a function attribute. - * - * \param attr_key The attribute key. - * \param default_value The default value if the key does not exist, defaults to nullptr. - * - * \return The result - * - * \tparam TOBjectRef the expected object type. - * \throw Error if the key exists but the value does not match TObjectRef - * - * \code - * - * void GetAttrExample(const BaseFunc& f) { - * auto value = f->GetAttr("AttrKey", 0); - * } - * - * \endcode - */ - template - Optional GetAttr( - const std::string& attr_key, - Optional default_value = Optional(nullptr)) const { - static_assert(std::is_base_of::value, - "Can only call GetAttr with ObjectRef types."); - if (!attrs.defined()) return default_value; - auto it = attrs->dict.find(attr_key); - if (it != attrs->dict.end()) { - return Downcast>((*it).second); - } else { - return default_value; - } - } - // variant that uses TObjectRef to enable implicit conversion to default value. - template - Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { - return GetAttr(attr_key, Optional(default_value)); - } - /*! - * \brief Check whether the function has an non-zero integer attr. - * - * This function can be used to check whether an optional - * attribute mark(e.g. inline) exists. - * - * \param attr_key The key to the attribute. - * \return The check result. - * - * \code - * - * void HasNonzeroAttrExample(const BaseFunc& f) { - * if (f->HasNonzeroAttr(attr::kInline)) { - * // inline the function. - * } - * } - * - * \endcode - */ - bool HasNonzeroAttr(const std::string& attr_key) const { - return GetAttr(attr_key, 0) != 0; - } - static constexpr const char* _type_key = "BaseFunc"; static constexpr const uint32_t _type_child_slots = 2; TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode); @@ -154,48 +93,6 @@ class BaseFunc : public RelayExpr { TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode); }; -/*! - * \brief Create a new function that copies func, but overrides - * the attribute value key with the value. - * - * \param func The input function. - * \param attr_key The attribute key. - * \param attr_value The value attribute value. - * - * \tparam TFunc The corresponding function type. - * - * \returns The new function with updated attributes. - * - * \note This function performs copy on write optimization for func. - * If we move a uniquely referenced func into WithAttr, - * then no additional copy will be performed. - * - * This is also why we make it as a function instead of a member function - * and why we pass by value in the first argument. - * - * \code - * - * // Recommended way to trigger copy on write - * func = WithAttr(std::move(func), "key1", value1); - * func = WithAttr(std::move(func), "key2", value2); - * - * \endcode - */ -template ::value>::type> -inline TFunc WithAttr(TFunc func, const std::string& attr_key, ObjectRef attr_value) { - using TNode = typename TFunc::ContainerType; - static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); - TNode* node = func.CopyOnWrite(); - if (node->attrs.defined()) { - node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value); - } else { - Map dict = {{attr_key, attr_value}}; - node->attrs = DictAttrs(dict); - } - return func; -} - /*! * \brief Generic attribute names that can be attached to any function. * diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 638f132e3179..c88f438c7cd8 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -58,6 +58,8 @@ class IRModuleNode : public Object { Map type_definitions; /*! \brief The source map for the module. */ parser::SourceMap source_map; + /* \brief Additional attributes storing meta-data about the module. */ + DictAttrs attrs; IRModuleNode() : source_map() {} From 1e707d3bdde3a728ef8ac00f8cc23d4020d950ba Mon Sep 17 00:00:00 2001 From: electriclilies Date: Mon, 16 Aug 2021 11:26:51 -0700 Subject: [PATCH 8/9] Revert a few more things --- src/ir/module.cc | 2 +- src/relay/transforms/annotate_target.cc | 4 ++-- src/relay/transforms/fold_constant.cc | 2 +- src/relay/transforms/fuse_ops.cc | 4 ++-- src/relay/transforms/inline.cc | 6 +++--- src/relay/transforms/memory_alloc.cc | 4 ++-- src/relay/transforms/partial_eval.cc | 2 +- src/relay/transforms/partition_graph.cc | 4 ++-- src/relay/transforms/to_a_normal_form.cc | 4 ++-- src/relay/transforms/to_basic_block_normal_form.cc | 2 +- src/relay/transforms/to_cps.cc | 2 +- 11 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/ir/module.cc b/src/ir/module.cc index a62656aa69fb..7990b281fb04 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -356,7 +356,7 @@ IRModule IRModule::FromExpr(const RelayExpr& expr, if (auto* func_node = expr.as()) { func = GetRef(func_node); - if (auto opt = func->attrs.GetAttr(tvm::attr::kGlobalSymbol)) { + if (auto opt = func->GetAttr(tvm::attr::kGlobalSymbol)) { gv_name = opt.value(); } diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 19c96a161483..b12e25a425b6 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -227,7 +227,7 @@ class AnnotateTargetRewriter : public ExprRewriter { // if it is in the target list. Function func = Downcast(pre->op); ICHECK(func.defined()); - if (auto comp_name = func->attrs.GetAttr(attr::kComposite)) { + if (auto comp_name = func->GetAttr(attr::kComposite)) { std::string comp_name_str = comp_name.value(); size_t i = comp_name_str.find('.'); if (i != std::string::npos) { @@ -288,7 +288,7 @@ class AnnotateTargetRewriter : public ExprRewriter { Function func; Expr new_body; // don't step into composite functions - if (fn->attrs.GetAttr(attr::kComposite).defined()) { + if (fn->GetAttr(attr::kComposite).defined()) { func = GetRef(fn); new_body = func->body; } else { diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 2a3def468ab8..57603035b848 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -118,7 +118,7 @@ class ConstantFolder : public MixedModeMutator { bool inside_primitive = false; Expr VisitExpr_(const FunctionNode* op) final { - if (op->attrs.HasNonzeroAttr(attr::kPrimitive)) { + if (op->HasNonzeroAttr(attr::kPrimitive)) { ICHECK_EQ(inside_primitive, false); inside_primitive = true; auto ret = ExprMutator::VisitExpr_(op); diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index 1542e6d68d86..f1f7a95e33e8 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -199,7 +199,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { // Post order tree void VisitExpr_(const FunctionNode* op) final { // Skip the function that should be handled by external codegen. - if (op->attrs.GetAttr(attr::kCompiler).defined()) return; + if (op->GetAttr(attr::kCompiler).defined()) return; for (auto param : op->params) { this->Update(param, nullptr, kOpaque); @@ -856,7 +856,7 @@ class FuseMutator : private MixedModeMutator { // Skip primitive function. Expr VisitExpr_(const FunctionNode* fn_node) { - if (fn_node->attrs.HasNonzeroAttr(attr::kPrimitive)) { + if (fn_node->HasNonzeroAttr(attr::kPrimitive)) { return GetRef(fn_node); } else { return ExprMutator::VisitExpr_(fn_node); diff --git a/src/relay/transforms/inline.cc b/src/relay/transforms/inline.cc index a8ce7bd22bbf..6e6505b28dc6 100644 --- a/src/relay/transforms/inline.cc +++ b/src/relay/transforms/inline.cc @@ -99,7 +99,7 @@ class Inliner : ExprMutator { if (!func->body.defined()) return false; // The function must be annotated with the inline attribute. - if (!func->attrs.HasNonzeroAttr(attr::kInline)) return false; + if (!func->HasNonzeroAttr(attr::kInline)) return false; // The function is not abled to be inlined if any callee under the CallGraph // of this function cannot be inlined. @@ -122,7 +122,7 @@ class Inliner : ExprMutator { auto func = Function(fn->params, fn->body, fn->ret_type, fn->type_params, fn->attrs); // Inline the function body to the caller if this function uses default // compiler, i.e. no external codegen is needed. - if (!func->attrs.GetAttr(attr::kCompiler).defined()) { + if (!func->GetAttr(attr::kCompiler).defined()) { ICHECK_EQ(func->params.size(), args.size()) << "Mismatch found in the number of parameters and call args"; // Bind the parameters with call args. @@ -188,7 +188,7 @@ IRModule Inline(const IRModule& module) { auto base_func = cg->GetGlobalFunction(cgn->GetGlobalVar()); if (const auto* fn = base_func.as()) { auto func = GetRef(fn); - if (func->attrs.HasNonzeroAttr(attr::kInline)) { + if (func->HasNonzeroAttr(attr::kInline)) { ICHECK_EQ(cgn->GetRefCount(), 0U) << cgn->GetNameHint() << " is marked as inline but not inlined."; cgn->CleanCallGraphEntries(); diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index 7717c741951a..657e2c392455 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -71,7 +71,7 @@ inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dt // Check if the primitive function contains only reshape ops. bool IsReshapeOnly(const Expr& expr) { if (const FunctionNode* func = expr.as()) { - return func->attrs.HasNonzeroAttr(attr::kReshapeOnly); + return func->HasNonzeroAttr(attr::kReshapeOnly); } if (const CallNode* call = expr.as()) { if (call->attrs.defined()) { @@ -199,7 +199,7 @@ class DialectRewriter : public ExprMutator { // Check if a call invokes a primitive function. bool IsPrimitive(const CallNode* call) const { if (const auto* fn = call->op.as()) { - return fn->attrs.HasNonzeroAttr(attr::kPrimitive); + return fn->HasNonzeroAttr(attr::kPrimitive); } return false; } diff --git a/src/relay/transforms/partial_eval.cc b/src/relay/transforms/partial_eval.cc index fb49d69855f0..9572faf08714 100644 --- a/src/relay/transforms/partial_eval.cc +++ b/src/relay/transforms/partial_eval.cc @@ -776,7 +776,7 @@ class PartialEvaluator : public ExprFunctor Func VisitFuncStatic(const Function& func, const Expr& var) { ICHECK(IsAtomic(var)); - if (func->attrs.HasNonzeroAttr(attr::kPrimitive)) { + if (func->HasNonzeroAttr(attr::kPrimitive)) { return ConstEvaluateFunc(func); } std::vector > free_vars; diff --git a/src/relay/transforms/partition_graph.cc b/src/relay/transforms/partition_graph.cc index d67da140a803..b48fbe44bd11 100644 --- a/src/relay/transforms/partition_graph.cc +++ b/src/relay/transforms/partition_graph.cc @@ -499,7 +499,7 @@ class NameMangleExtFuncs : public MixedModeMutator { for (const auto& pair : glob_funcs) { if (auto* fn = pair.second.as()) { auto func = GetRef(fn); - if (func->attrs.GetAttr(attr::kCompiler).defined()) { + if (func->GetAttr(attr::kCompiler).defined()) { auto fn_name_mangled = mangle_fn_(pair.first->name_hint); GlobalVar gvar = GlobalVar(fn_name_mangled); mangled_gvars_[pair.first->name_hint] = gvar; @@ -514,7 +514,7 @@ class NameMangleExtFuncs : public MixedModeMutator { if (auto* fn = pair.second.as()) { auto func = GetRef(fn); - if (func->attrs.GetAttr(attr::kCompiler).defined()) { + if (func->GetAttr(attr::kCompiler).defined()) { auto new_dict = func->attrs->dict; new_dict.Set(tvm::attr::kGlobalSymbol, String(mangle_fn_(pair.first->name_hint))); func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params, diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 189165131b8e..91e8d90c1232 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -200,7 +200,7 @@ Expr Fill::VisitExpr_(const IfNode* i, const Var& v) { Expr Fill::VisitExpr_(const FunctionNode* f, const Var& v) { Expr e = GetRef(f); Expr ret; - if (f->attrs.HasNonzeroAttr(attr::kPrimitive)) { + if (f->HasNonzeroAttr(attr::kPrimitive)) { ret = e; } else { ret = Function(f->params, GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body)), f->ret_type, @@ -260,7 +260,7 @@ IRModule ToANormalForm(const IRModule& m) { for (const auto& it : funcs) { ICHECK_EQ(FreeVars(it.second).size(), 0); if (const auto* n = it.second.as()) { - if (n->attrs.GetAttr(attr::kCompiler).defined()) continue; + if (n->GetAttr(attr::kCompiler).defined()) continue; } Expr ret = TransformF([&](const Expr& e) { return transform::ToANormalForm(e); }, it.second); ICHECK_EQ(FreeVars(ret).size(), 0) diff --git a/src/relay/transforms/to_basic_block_normal_form.cc b/src/relay/transforms/to_basic_block_normal_form.cc index fbab6a91b824..79157bba1918 100644 --- a/src/relay/transforms/to_basic_block_normal_form.cc +++ b/src/relay/transforms/to_basic_block_normal_form.cc @@ -56,7 +56,7 @@ IRModule ToBasicBlockNormalForm(const IRModule& mod) { for (const auto& it : funcs) { ICHECK_EQ(FreeVars(it.second).size(), 0) << "Expected no free variables"; if (const auto* n = it.second.as()) { - if (n->attrs.GetAttr(attr::kCompiler).defined()) continue; + if (n->GetAttr(attr::kCompiler).defined()) continue; } Expr ret = TransformF([&](const Expr& e) { return ToBasicBlockNormalFormAux(e); }, it.second); updates.Set(it.first, Downcast(ret)); diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc index 5dc0198d9872..b7f9cafbc7dc 100644 --- a/src/relay/transforms/to_cps.cc +++ b/src/relay/transforms/to_cps.cc @@ -134,7 +134,7 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm, } Expr VisitExpr_(const FunctionNode* op, const MCont& k) final { - ICHECK(!op->attrs.HasNonzeroAttr(attr::kPrimitive)) << "primitive func not supported yet."; + ICHECK(!op->HasNonzeroAttr(attr::kPrimitive)) << "primitive func not supported yet."; return k(ToCPS(GetRef(op), m, cm, vm, answer)); } From 92641594bb33763803d5412572bca6ed03988943 Mon Sep 17 00:00:00 2001 From: electriclilies Date: Mon, 16 Aug 2021 12:40:04 -0700 Subject: [PATCH 9/9] Add GetAttrs to IRModuleNode --- include/tvm/ir/attrs.h | 2 +- include/tvm/ir/function.h | 52 +++++++++++++++++++++++++++++++++++++++ include/tvm/ir/module.h | 52 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 105 insertions(+), 1 deletion(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 74c98da00189..fa1861051e2f 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -288,7 +288,7 @@ class DictAttrs : public Attrs { * \code * * void HasNonzeroAttrExample(const BaseFunc& f) { - * if (f->attrs.HasNonzeroAttr(attr::kInline)) { + * if (f->HasNonzeroAttr(attr::kInline)) { * // inline the function. * } * } diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 020b3de77ab3..13b984d9cb35 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -79,6 +79,58 @@ class BaseFuncNode : public RelayExprNode { /*! \brief Additional attributes storing the meta-data */ DictAttrs attrs; + /*! + * \brief Get a function attribute. + * + * \param attr_key The attribute key. + * \param default_value The default value if the key does not exist, defaults to nullptr. + * + * \return The result + * + * \tparam TOBjectRef the expected object type. + * \throw Error if the key exists but the value does not match TObjectRef + * + * \code + * + * void GetAttrExample(const BaseFunc& f) { + * auto value = f->GetAttr("AttrKey", 0); + * } + * + * \endcode + */ + template + Optional GetAttr( + const std::string& attr_key, + Optional default_value = Optional(nullptr)) const { + return attrs.GetAttr(attr_key, default_value); + } + // variant that uses TObjectRef to enable implicit conversion to default value. + template + Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, Optional(default_value)); + } + + /*! + * \brief Check whether the function has an non-zero integer attr. + * + * This function can be used to check whether an optional + * attribute mark(e.g. inline) exists. + * + * \param attr_key The key to the attribute. + * \return The check result. + * + * \code + * + * void HasNonzeroAttrExample(const BaseFunc& f) { + * if (f->HasNonzeroAttr(attr::kInline)) { + * // inline the function. + * } + * } + * + * \endcode + */ + bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); } + static constexpr const char* _type_key = "BaseFunc"; static constexpr const uint32_t _type_child_slots = 2; TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode); diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index c88f438c7cd8..9ca27ec3b661 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -61,6 +61,58 @@ class IRModuleNode : public Object { /* \brief Additional attributes storing meta-data about the module. */ DictAttrs attrs; + /*! + * \brief Get a module attribute. + * + * \param attr_key The attribute key. + * \param default_value The default value if the key does not exist, defaults to nullptr. + * + * \return The result + * + * \tparam TOBjectRef the expected object type. + * \throw Error if the key exists but the value does not match TObjectRef + * + * \code + * + * void GetAttrExample(const IRModule& mod) { + * auto value = f->GetAttr("AttrKey", 0); + * } + * + * \endcode + */ + template + Optional GetAttr( + const std::string& attr_key, + Optional default_value = Optional(nullptr)) const { + return attrs.GetAttr(attr_key, default_value); + } + // variant that uses TObjectRef to enable implicit conversion to default value. + template + Optional GetAttr(const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, Optional(default_value)); + } + + /*! + * \brief Check whether the module has an non-zero integer attr. + * + * This function can be used to check whether an optional + * attribute mark(e.g. inline) exists. + * + * \param attr_key The key to the attribute. + * \return The check result. + * + * \code + * + * void HasNonzeroAttrExample(const IRModule& mod) { + * if (mod->HasNonzeroAttr(attr::kInline)) { + * // inline the function. + * } + * } + * + * \endcode + */ + bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); } + IRModuleNode() : source_map() {} void VisitAttrs(AttrVisitor* v) {