From 1c54e3f189079c68acad7f67dcd3e90c613c971b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 19 Mar 2024 08:18:49 -0500 Subject: [PATCH 1/7] [IR] Default to empty attributes, instead of NULL Prior to this commit, the default `DictAttrs` for an `IRModule`, `tir::PrimFunc`, `relax::Function`, and `relay::Function` was a null value. At each callsite, the absence of a `DictAttrs` needed to be treated as equivalent to an empty `DictAttrs`. In C++, this typically was done using the `foo->GetAttr` helper function, but in Python it needed to be checked explicitly. That is, every callsite needed to check `if func.attrs is not None and attr_name in func.attrs`, rather than only checking `if attr_name in func.attrs`. Since most functions would have at least one attribute to specify the global symbol, these bugs would often surface when working on unrelated changes. This commit changes the default attribute dictionary from `NullValue()` to `DictAttrs()`. This avoids having two separate representations of an object without any attributes, and allows the `if attr_name in func.attrs` pattern in the Python API. --- include/tvm/ir/attrs.h | 7 ++---- include/tvm/ir/module.h | 3 ++- include/tvm/relax/expr.h | 5 ++--- include/tvm/relay/function.h | 2 +- include/tvm/runtime/object.h | 14 ++++++++++-- include/tvm/script/ir_builder/tir/frame.h | 2 +- include/tvm/tir/function.h | 2 +- src/relay/analysis/type_solver.cc | 2 +- src/relay/backend/vm/lambda_lift.cc | 2 +- src/relay/ir/dataflow_matcher.cc | 2 +- src/relay/ir/function.cc | 2 +- src/relay/transforms/dynamic_to_static.cc | 3 +-- src/relay/transforms/to_cps.cc | 4 ++-- src/script/ir_builder/ir/frame.cc | 2 +- src/script/ir_builder/tir/frame.cc | 14 ++---------- src/script/ir_builder/tir/ir.cc | 27 ++++++++++++++++------- 16 files changed, 50 insertions(+), 43 deletions(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 18d0f025c776..81611b1a535a 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -230,7 +230,7 @@ class DictAttrs : public Attrs { * \brief Consruct a Attrs backed by DictAttrsNode. * \param dict The attributes. */ - TVM_DLL explicit DictAttrs(Map dict); + TVM_DLL explicit DictAttrs(Map dict = {}); // Utils for accessing attributes // This needs to be on DictAttrs, not DictAttrsNode because we return the default @@ -298,7 +298,7 @@ class DictAttrs : public Attrs { return GetAttr(attr_key, 0).value_or(0).IntValue() != 0; } - TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode); + TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(DictAttrs, Attrs, DictAttrsNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode); }; @@ -415,9 +415,6 @@ inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) { if (input->attrs.defined()) { TNode* node = input.CopyOnWrite(); node->attrs.CopyOnWrite()->dict.erase(attr_key); - if (node->attrs->dict.size() == 0) { - node->attrs = NullValue(); - } } return input; } diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index ad6efa529cc2..2a5412a5671f 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -376,7 +376,8 @@ class IRModule : public ObjectRef { TVM_DLL explicit IRModule(Map functions, Map type_definitions = {}, std::unordered_set import_set = {}, SourceMap map = {}, - DictAttrs attrs = {}, Map> global_infos = {}); + DictAttrs attrs = DictAttrs(), + Map> global_infos = {}); /*! \brief default constructor */ IRModule() : IRModule(Map({})) {} diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index fdbd7bd8eb2c..4634d1e228d3 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -983,15 +983,14 @@ class FunctionNode : public BaseFuncNode { class Function : public BaseFunc { public: TVM_DLL explicit Function(Array params, Expr body, Optional ret_struct_info, - bool is_pure = true, DictAttrs attrs = NullValue(), - Span span = Span()); + bool is_pure = true, DictAttrs attrs = DictAttrs(), Span span = Span()); /*! * \brief Mimics the constructor but without body Expr. * \note ret_struct_info is required, since it can not deduced by the body. */ TVM_DLL static Function CreateEmpty(Array params, StructInfo ret_struct_info, - bool is_pure = true, DictAttrs attrs = NullValue(), + bool is_pure = true, DictAttrs attrs = DictAttrs(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode); diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index 874d4f233416..798f6d4d2566 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -114,7 +114,7 @@ class Function : public BaseFunc { * \param span The span of the function. */ TVM_DLL Function(tvm::Array params, Expr body, Type ret_type, tvm::Array ty_params, - tvm::DictAttrs attrs = NullValue(), Span span = Span()); + tvm::DictAttrs attrs = DictAttrs(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode); diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 92f477b058fd..172316daae59 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -741,14 +741,24 @@ struct ObjectPtrEqual { * \param ParentType The parent type of the objectref * \param ObjectName The type name of the object. */ -#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - TypeName() = default; \ +#define TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TypeName, ParentType, \ + ObjectName) \ explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ const ObjectName* operator->() const { return static_cast(data_.get()); } \ const ObjectName* get() const { return operator->(); } \ using ContainerType = ObjectName; +/* + * \brief Define object reference methods. + * \param TypeName The object type name + * \param ParentType The parent type of the objectref + * \param ObjectName The type name of the object. + */ +#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + TypeName() = default; \ + TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TypeName, ParentType, ObjectName) + /* * \brief Define object reference methods that is not nullable. * diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index 0cc385d876a8..598750f0ac48 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -78,7 +78,7 @@ class PrimFuncFrameNode : public TIRFrameNode { /*! \brief Maps some parameters to specific Buffer data structures. */ Map buffer_map; /*! \brief Additional attributes storing the meta-data */ - Optional> attrs; + Map attrs; /*! \brief The variable map bound to thread env. */ Map env_threads; /*! \brief The buffer allocated in root block. */ diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 1917a3c22c6e..274ebd0a6558 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -164,7 +164,7 @@ class PrimFunc : public BaseFunc { */ TVM_DLL PrimFunc(Array params, Stmt body, Type ret_type = VoidType(), Map buffer_map = Map(), - DictAttrs attrs = NullValue(), Span span = Span()); + DictAttrs attrs = DictAttrs(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode); TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode); diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 5bd5698d8321..c4fab210acb8 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -659,7 +659,7 @@ TVM_REGISTER_GLOBAL("relay.analysis._test_type_solver") auto module = IRModule({}, {}); DiagnosticContext diag_ctx = DiagnosticContext::Default(module); auto dummy_fn_name = GlobalVar("test"); - module->Add(dummy_fn_name, Function({}, Tuple(tvm::Array({})), Type(), {}, {})); + module->Add(dummy_fn_name, Function({}, Tuple(tvm::Array({})), Type(), {})); auto solver = std::make_shared(dummy_fn_name, diag_ctx); auto mod = [module, solver, diag_ctx](std::string name) -> PackedFunc { diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index ba94e4b19ec7..48449eb02149 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -194,7 +194,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { CHECK_EQ(before_arity, after_arity); lifted_func = Function(typed_captured_vars, rebound_body, /*ret_type=*/func->func_type_annotation(), - free_type_vars, /*attrs=*/{}, func->span); + free_type_vars, DictAttrs(), func->span); lifted_func->virtual_device_ = result_virtual_device; lifted_func = MarkClosure(lifted_func); } diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index ee585446cb26..8e756a8aa2d3 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -438,7 +438,7 @@ Expr InferTypeWithModule(const Expr& expr, const IRModule& m) { if (expr.as()) { func = Downcast(expr); } else { - func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {}); + func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod)); } mod->Add(gvar, func); mod = transform::InferType()(mod); diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index fd8c646ecf1c..8d2477f109a4 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -251,7 +251,7 @@ TVM_REGISTER_GLOBAL("relay.ir.IRModuleUpdateWithRenamer") TVM_REGISTER_GLOBAL("relay.ir.FunctionFromExprInContext") .set_body_typed([](RelayExpr expr, IRModule mod) -> Function { - return Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {}); + return Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod)); }); TVM_REGISTER_GLOBAL("relay.ir.FuncWithAttr") diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index a989cf53f818..c192097a0b29 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -253,8 +253,7 @@ class DynamicToStaticMutator : public MixedModeMutator { if (auto func_node = expr.as()) { func = func_node.value(); } else { - func = - relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod_), {}); + func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod_)); } mod_->Update(gv_, func); diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc index 7c90d101b567..05d49cf5047c 100644 --- a/src/relay/transforms/to_cps.cc +++ b/src/relay/transforms/to_cps.cc @@ -170,7 +170,7 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm, Expr reify(const MCont& k) { Var arg = Var("arg", Type()); - return Function({arg}, k(arg), Type(), {}, {}); + return Function({arg}, k(arg), Type(), {}); } Expr reify(const MCont& k, const std::function& cont) { @@ -328,7 +328,7 @@ Function UnCPS(const Function& f) { // TODO(@M.K.): make alphaequal work on free term // ICHECK(tvm::StructuralEqual()(cont_type, Arrow(new_ret_type, answer_type))); auto x = Var("x", new_ret_type); - auto cont = Function({x}, x, new_ret_type, {}, {}); + auto cont = Function({x}, x, new_ret_type, {}); tvm::Array args; for (const auto& p : new_params) { args.push_back(p); diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index 3d917cee887b..60a35ee010ec 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -38,7 +38,7 @@ void IRModuleFrameNode::ExitWithScope() { } IRBuilder builder = IRBuilder::Current(); ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; - auto dict_attrs = attrs.empty() ? NullValue() : DictAttrs(attrs); + auto dict_attrs = DictAttrs(attrs); builder->result = tvm::IRModule(func_map, {}, {}, {}, dict_attrs, global_infos); } diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index c15a290bf03d..42adebf2d848 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -33,17 +33,7 @@ void PrimFuncFrameNode::ExitWithScope() { // if the prim func is not private and there isn't already a global symbol, // add a global symbol if (!is_private && name.defined()) { - if (!attrs.defined()) { - attrs = {{tvm::attr::kGlobalSymbol, name.value()}}; - } else if (!attrs.value().count(tvm::attr::kGlobalSymbol)) { - // copy over attributes (can't mutate the dict inside the optional in-place) - Map new_attrs; - for (auto kv : attrs.value()) { - new_attrs.Set(kv.first, kv.second); - } - new_attrs.Set(tvm::attr::kGlobalSymbol, name.value()); - attrs = std::move(new_attrs); - } + attrs.Set(tvm::attr::kGlobalSymbol, name.value()); } tvm::tir::PrimFunc func( @@ -51,7 +41,7 @@ void PrimFuncFrameNode::ExitWithScope() { /*body=*/AsStmt(stmts), /*ret_type=*/ret_type.value_or(TupleType::Empty()), /*buffer_map=*/buffer_map, - /*attrs=*/attrs.defined() ? DictAttrs(attrs.value()) : NullValue()); + /*attrs=*/DictAttrs(attrs)); func = tvm::tir::ScriptComplete(func, root_alloc_buffers); IRBuilder builder = IRBuilder::Current(); if (builder->frames.empty()) { diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index cf73ffa0eedd..aa8b7131ef2b 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -61,7 +61,7 @@ PrimFuncFrame PrimFunc(bool is_private) { n->args.clear(); n->ret_type = NullOpt; n->buffer_map.clear(); - n->attrs = NullOpt; + n->attrs = {}; n->env_threads.clear(); n->root_alloc_buffers.clear(); return PrimFuncFrame(n); @@ -91,16 +91,27 @@ void FuncName(String name) { frame->name = name; } -void FuncAttrs(Map attrs) { +void FuncAttrs(Map new_attrs) { using namespace tvm::tir; PrimFuncFrame frame = FindPrimFuncFrame("T.func_attr"); - if (frame->attrs.defined()) { - LOG(FATAL) << "ValueError: Duplicate prim func annotations, previous one is " << frame->attrs; - } - if (attrs.count(tvm::attr::kGlobalSymbol) && frame->is_private) { - LOG(FATAL) << "ValueError: Specifying the global symbol even though the PrimFunc is annotated " - "as private"; + auto attrs = frame->attrs; + for (const auto& [key, value] : new_attrs) { + if (key == tvm::attr::kGlobalSymbol && frame->is_private) { + LOG(FATAL) << "ValueError: " + << "A private function may not have the kGlobalSymbol (\"" + << tvm::attr::kGlobalSymbol << "\") attribute. " + << "However, a private function specified the global symbol as " << value; + } + + if (auto prev = attrs.Get(key)) { + LOG(FATAL) << "ValueError: " + << "Duplicate prim func annotation for key = \"" << key << "\". " + << "Previous value was " << prev.value() << ", with later definition as " << value; + } else { + frame->attrs.Set(key, value); + } } + frame->attrs = attrs; } From afc6baf6a36effe8fddccd5dcaca80d9e27ecd1e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 19 Mar 2024 08:36:37 -0500 Subject: [PATCH 2/7] Remove no-longer-needed checks on attrs being present --- python/tvm/contrib/cutlass/build.py | 2 +- python/tvm/contrib/relay_viz/interface.py | 23 +++++++++---------- python/tvm/dlight/base/transform.py | 2 -- python/tvm/dlight/gpu/matmul.py | 4 ++-- python/tvm/driver/build_module.py | 2 +- python/tvm/meta_schedule/relax_integration.py | 2 +- python/tvm/relax/backend/contrib/cutlass.py | 4 ++-- python/tvm/relax/frontend/common.py | 2 +- python/tvm/relax/training/setup_trainer.py | 12 ++++------ .../relax/transform/lazy_transform_params.py | 4 ++-- .../tvm/relay/backend/contrib/ethosu/util.py | 2 +- .../relay/quantize/_partition_conversions.py | 4 ++-- python/tvm/relay/testing/py_converter.py | 2 +- python/tvm/relay/transform/recast.py | 2 +- python/tvm/testing/aot.py | 2 +- tests/python/contrib/test_coreml_codegen.py | 2 +- tests/python/contrib/test_dnnl.py | 2 +- .../test_meta_schedule_cpu_dot_product.py | 2 +- tests/python/relax/test_codegen_cutlass.py | 6 ++--- tests/python/tir-base/test_tir_nodes.py | 2 +- .../test_tir_transform_helpers.py | 20 ++++++++-------- 21 files changed, 48 insertions(+), 55 deletions(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 80169f51640e..59803f20feb5 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -977,7 +977,7 @@ def handle_norm(self, f, op_type): return f.with_attrs(attrs) def visit_function_(self, f): - if f.attrs is None or "Composite" not in f.attrs: + if b"Composite" not in f.attrs: body = super().visit_expr(f.body) return relax.Function(f.params, body, f.ret_struct_info, f.is_pure, f.attrs, f.span) diff --git a/python/tvm/contrib/relay_viz/interface.py b/python/tvm/contrib/relay_viz/interface.py index 15dbbf9fd6b6..8df188fcf42e 100644 --- a/python/tvm/contrib/relay_viz/interface.py +++ b/python/tvm/contrib/relay_viz/interface.py @@ -213,14 +213,14 @@ def _function( node_to_id: Dict[relay.Expr, str], ) -> Tuple[Union[VizNode, None], List[VizEdge]]: """Render rule for a relay function node""" - node_details = [] - name = "" func_attrs = node.attrs - if func_attrs: - node_details = [f"{k}: {func_attrs.get_str(k)}" for k in func_attrs.keys()] - # "Composite" might from relay.transform.MergeComposite - if "Composite" in func_attrs.keys(): - name = func_attrs["Composite"] + node_details = [f"{k}: {func_attrs.get_str(k)}" for k in func_attrs.keys()] + # "Composite" might from relay.transform.MergeComposite + if "Composite" in func_attrs.keys(): + name = func_attrs["Composite"] + else: + name = "" + node_id = node_to_id[node] # Body -> FunctionNode @@ -244,11 +244,10 @@ def _call( elif isinstance(node.op, relay.Function): func_attrs = node.op.attrs op_name = "Anonymous Func" - if func_attrs: - node_detail = [f"{k}: {func_attrs.get_str(k)}" for k in func_attrs.keys()] - # "Composite" might from relay.transform.MergeComposite - if "Composite" in func_attrs.keys(): - op_name = func_attrs["Composite"] + node_detail = [f"{k}: {func_attrs.get_str(k)}" for k in func_attrs.keys()] + # "Composite" might from relay.transform.MergeComposite + if "Composite" in func_attrs.keys(): + op_name = func_attrs["Composite"] elif isinstance(node.op, relay.GlobalVar): op_name = "GlobalVar" node_detail = [f"GlobalVar.name_hint: {node.op.name_hint}"] diff --git a/python/tvm/dlight/base/transform.py b/python/tvm/dlight/base/transform.py index 89ecaa6350fb..d697e9440b31 100644 --- a/python/tvm/dlight/base/transform.py +++ b/python/tvm/dlight/base/transform.py @@ -31,8 +31,6 @@ def _is_scheduled(func: tir.PrimFunc) -> bool: if not isinstance(func, tir.PrimFunc): return False - if not func.attrs: - return False if "tir.is_scheduled" not in func.attrs: return False return func.attrs["tir.is_scheduled"] == 1 diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py index 9318b9149245..0f224b89f9e4 100644 --- a/python/tvm/dlight/gpu/matmul.py +++ b/python/tvm/dlight/gpu/matmul.py @@ -335,7 +335,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring root_block = analysis.get_root_block(sch) blocks = sch.get_child_blocks(root_block) - if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + if "dlight.do_not_tensorize" in func.attrs.keys(): return None reduction_blocks = get_reduction_blocks(sch, blocks) @@ -556,7 +556,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring root_block = analysis.get_root_block(sch) blocks = sch.get_child_blocks(root_block) - if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + if "dlight.do_not_tensorize" in func.attrs.keys(): return None reduction_blocks = get_reduction_blocks(sch, blocks) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index e23765e92d8c..c332062b37b9 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -249,7 +249,7 @@ def build( if target is None and isinstance(input_mod, tvm.IRModule): target_mod = {} for gvar, func in input_mod.functions.items(): - tgt = func.attrs["target"] if func.attrs and "target" in func.attrs else "llvm" + tgt = func.attrs["target"] if "target" in func.attrs else "llvm" if tgt not in target_mod: target_mod[tgt] = {} target_mod[tgt][gvar] = func diff --git a/python/tvm/meta_schedule/relax_integration.py b/python/tvm/meta_schedule/relax_integration.py index 57daeea2d97b..c3c24aa631d6 100644 --- a/python/tvm/meta_schedule/relax_integration.py +++ b/python/tvm/meta_schedule/relax_integration.py @@ -138,7 +138,7 @@ def extracted_tasks_to_tune_contexts( get_loggers_from_work_dir(work_dir, [t.task_name for t in extracted_tasks]), fork_seed(seed, n=len(extracted_tasks)), ): - if task.mod.attrs is not None and task.mod.attrs.get("tir.is_scheduled", False): + if task.mod.attrs.get("tir.is_scheduled", False): warnings.warn("The task {task.task_name} is already scheduled, skipping it.") continue tasks.append( diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py index a611bee2bbcd..0d9f4ff8e923 100644 --- a/python/tvm/relax/backend/contrib/cutlass.py +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -526,11 +526,11 @@ def __init__(self, mod): super().__init__(mod) def visit_function_(self, f): - if f.attrs is None or "Composite" not in f.attrs: + if "Composite" not in f.attrs: body = super().visit_expr(f.body) new_f = Function(f.params, body, f.ret_struct_info, f.is_pure, f.attrs, f.span) - if f.attrs and "global_symbol" in f.attrs and "cutlass" in f.attrs["global_symbol"]: + if "global_symbol" in f.attrs and "cutlass" in f.attrs["global_symbol"]: composite_func = body.blocks[0].bindings[0].value if "WorkspaceSize" in composite_func.attrs: return new_f.with_attr("WorkspaceSize", composite_func.attrs["WorkspaceSize"]) diff --git a/python/tvm/relax/frontend/common.py b/python/tvm/relax/frontend/common.py index cc36bbbc72ba..bbd0c55aac2e 100644 --- a/python/tvm/relax/frontend/common.py +++ b/python/tvm/relax/frontend/common.py @@ -42,7 +42,7 @@ def detach_params(mod: tvm.IRModule) -> Tuple[tvm.IRModule, Dict[str, List[tvm.n detached_mod = tvm.IRModule() params_dict = dict() for gv, func in mod.functions_items(): - if func.attrs is not None and "params" in func.attrs: + if "params" in func.attrs: params = list(func.attrs["params"]) if not all([isinstance(param, tvm.nd.NDArray) for param in params]): raise ValueError( diff --git a/python/tvm/relax/training/setup_trainer.py b/python/tvm/relax/training/setup_trainer.py index 2e2057086904..71bf8509a63e 100644 --- a/python/tvm/relax/training/setup_trainer.py +++ b/python/tvm/relax/training/setup_trainer.py @@ -138,19 +138,15 @@ def _check_well_formed(self, mod: IRModule): ) from exc # Check function attrs - if ( - mod.attrs is None - or not self.PARAM_NUM_ATTR_KEY in mod.attrs - or not isinstance(mod.attrs[self.PARAM_NUM_ATTR_KEY], IntImm) + if not self.PARAM_NUM_ATTR_KEY in mod.attrs or not isinstance( + mod.attrs[self.PARAM_NUM_ATTR_KEY], IntImm ): raise ValueError( f"SetupTrainer: The backbone module should has an integer attribute named " f"{self.PARAM_NUM_ATTR_KEY}" ) - if ( - mod.attrs is None - or not self.STATE_NUM_ATTR_KEY in mod.attrs - or not isinstance(mod.attrs[self.STATE_NUM_ATTR_KEY], IntImm) + if not self.STATE_NUM_ATTR_KEY in mod.attrs or not isinstance( + mod.attrs[self.STATE_NUM_ATTR_KEY], IntImm ): raise ValueError( f"SetupTrainer: The backbone module should has an integer attribute named " diff --git a/python/tvm/relax/transform/lazy_transform_params.py b/python/tvm/relax/transform/lazy_transform_params.py index e8e8229965c5..1b025f7d3a6a 100644 --- a/python/tvm/relax/transform/lazy_transform_params.py +++ b/python/tvm/relax/transform/lazy_transform_params.py @@ -138,7 +138,7 @@ def __init__( self.memory_free_insertion = None def transform(self, func: relax.Function) -> relax.Function: - if func.attrs is not None and "num_input" in func.attrs: + if "num_input" in func.attrs: num_input = func.attrs["num_input"].value else: num_input = 0 @@ -235,7 +235,7 @@ def __init__(self, func_creator, mod: Optional[IRModule] = None) -> None: super().__init__(mod) def visit_function_(self, func: relax.Function) -> relax.Expr: - if func.attrs is not None and "num_input" in func.attrs: + if "num_input" in func.attrs: num_input = func.attrs["num_input"].value else: num_input = 0 diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py index 289754d5c370..a402604b4c11 100644 --- a/python/tvm/relay/backend/contrib/ethosu/util.py +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -156,7 +156,7 @@ class QPadArgs(Enum): def is_npu_func(func: relay.Function) -> bool: """Check if the given function is an NPU function.""" - return func.attrs and "Compiler" in func.attrs and func.attrs["Compiler"] == "ethos-u" + return "Compiler" in func.attrs and func.attrs["Compiler"] == "ethos-u" def is_composite_func(func: relay.Function, name: str) -> bool: diff --git a/python/tvm/relay/quantize/_partition_conversions.py b/python/tvm/relay/quantize/_partition_conversions.py index 8ba5c9ae2f20..8fec69cdf53e 100644 --- a/python/tvm/relay/quantize/_partition_conversions.py +++ b/python/tvm/relay/quantize/_partition_conversions.py @@ -215,7 +215,7 @@ def partition_prefix(mod, quantized_dtypes): prefix_cutter = PrefixCutter(func.params, quantized_dtypes) mid_body = prefix_cutter.visit(func.body) assert not func.type_params, "unimplemented" - assert func.attrs is None, "unimplemented" + assert not func.attrs, "unimplemented" mid_func = relay.Function(relay.analysis.free_vars(mid_body), mid_body) mid_mod = tvm.IRModule.from_expr(mid_func) mid_mod = relay.transform.InferType()(mid_mod) @@ -288,7 +288,7 @@ def partition_suffix(mod, quantized_dtypes): suffix_cutter = SuffixCutter(quantized_dtypes) post_body = suffix_cutter.visit(func.body) assert not func.type_params, "unimplemented" - assert func.attrs is None, "unimplemented" + assert not func.attrs, "unimplemented" post_func = relay.Function(relay.analysis.free_vars(post_body), post_body, func.ret_type) post_mod = tvm.IRModule.from_expr(post_func) post_mod = relay.transform.InferType()(post_mod) diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index 9cbfcead4783..799a106ea613 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -553,7 +553,7 @@ def visit_call(self, call: Expr): # lowered operator: generate a call to a function that gets the PackedFunc # from TVM's registry - if isinstance(func, Function) and func.attrs and func.attrs.Primitive.value == 1: + if isinstance(func, Function) and func.attrs.Primitive.value == 1: op_call_def, op_call = self.create_op_call(func, call.args, fields) return (op_call, field_defs + [op_call_def]) diff --git a/python/tvm/relay/transform/recast.py b/python/tvm/relay/transform/recast.py index 39f07b2eb926..36824fb93103 100644 --- a/python/tvm/relay/transform/recast.py +++ b/python/tvm/relay/transform/recast.py @@ -75,7 +75,7 @@ def visit_call(self, call): # If out_dtype is in the attributes, we need to update it. orig_dtype = None - if call.attrs is not None and "out_dtype" in call.attrs.keys(): + if "out_dtype" in call.attrs.keys(): new_attr_dict = {} for attr in call.attrs.keys(): attr_value = call.attrs[attr] diff --git a/python/tvm/testing/aot.py b/python/tvm/testing/aot.py index 8d74f545a3c2..f6d99090510e 100644 --- a/python/tvm/testing/aot.py +++ b/python/tvm/testing/aot.py @@ -1001,7 +1001,7 @@ def generate_ref_data(mod, input_data, params=None, target="llvm"): main = mod else: main = mod["main"] - if main.attrs is None or main.attrs["output_tensor_names"] is None: + if main.attrs["output_tensor_names"] is None: output_tensor_names = ( ["output"] if output_count == 1 else [f"output{i}" for i in range(output_count)] ) diff --git a/tests/python/contrib/test_coreml_codegen.py b/tests/python/contrib/test_coreml_codegen.py index 2edfafaa0bd8..f0cdf14aa019 100644 --- a/tests/python/contrib/test_coreml_codegen.py +++ b/tests/python/contrib/test_coreml_codegen.py @@ -140,7 +140,7 @@ def _construct_model(func, m1, m2): fcompile = tvm._ffi.get_global_func("relay.ext.coremlcompiler") for var, func in mod.functions.items(): - if func.attrs and "Compiler" in func.attrs and func.attrs["Compiler"] == "coremlcompiler": + if "Compiler" in func.attrs and func.attrs["Compiler"] == "coremlcompiler": fcompile(func) diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py index c45149fc5f1e..27ff7b8a38dd 100644 --- a/tests/python/contrib/test_dnnl.py +++ b/tests/python/contrib/test_dnnl.py @@ -1374,7 +1374,7 @@ def _visit(node): op = node.op if isinstance(op, relay.GlobalVar): func = mod[op] - if "Compiler" in func.attrs and func.attrs["Compiler"] == desired_compiler: + if "Compiler" in func.attrs["Compiler"] == desired_compiler: matched_ops.append(op) return else: diff --git a/tests/python/meta_schedule/test_meta_schedule_cpu_dot_product.py b/tests/python/meta_schedule/test_meta_schedule_cpu_dot_product.py index 592c772a04dd..cc2731ff5974 100644 --- a/tests/python/meta_schedule/test_meta_schedule_cpu_dot_product.py +++ b/tests/python/meta_schedule/test_meta_schedule_cpu_dot_product.py @@ -43,7 +43,7 @@ def _schedule_dense(m: Optional[int], do_tune: bool, intrin=VNNI_INTRIN): """ def schedule_fn(sch, dense_block: Optional[BlockRV] = None) -> bool: - if sch.mod.attrs is not None and "dense" not in sch.mod.attrs["task_name"]: + if "dense" not in sch.mod.attrs["task_name"]: return False if dense_block is None: assert has_block(sch, "compute") diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index 11437f7d682a..fced7a84a832 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -328,7 +328,7 @@ def main( mod = partition_for_cutlass(Conv2dReLU, annotate_codegen=False) for f_var in mod.functions: func = mod[f_var] - if func.attrs and "Composite" in func.attrs: + if "Composite" in func.attrs: # verify that the function is not fused as residual block assert func.attrs["Composite"] == "cutlass.conv2d_bias_relu" @@ -554,7 +554,7 @@ def main( mod = partition_for_cutlass(TransposedMatmul, annotate_codegen=False) for f_var in mod.functions: func = mod[f_var] - if func.attrs and "Composite" in func.attrs: + if "Composite" in func.attrs: # verify that the function is not fused as transposed matmul assert func.attrs["Composite"] == "cutlass.matmul" @@ -575,7 +575,7 @@ def main(x: R.Tensor((128, 128), "float16"), w: R.Tensor((128, 128), "float16")) mod = partition_for_cutlass(Module, annotate_codegen=False) for f_var in mod.functions: func = mod[f_var] - if func.attrs and "Composite" in func.attrs: + if "Composite" in func.attrs: assert func.attrs["Composite"] == "cutlass.matmul" diff --git a/tests/python/tir-base/test_tir_nodes.py b/tests/python/tir-base/test_tir_nodes.py index f3498f8ec753..60f8278ec277 100644 --- a/tests/python/tir-base/test_tir_nodes.py +++ b/tests/python/tir-base/test_tir_nodes.py @@ -351,7 +351,7 @@ def test_prim_func(): assert len(func.buffer_map) == 1 f2 = func.with_attr({"calling_conv": 1, "tir.noalias": True}) assert f2.attrs["calling_conv"].value == 1 - assert func.attrs is None + assert not func.attrs def test_vars(): diff --git a/tests/python/tir-transform/test_tir_transform_helpers.py b/tests/python/tir-transform/test_tir_transform_helpers.py index d4cd01ade248..d2ea82a1402c 100644 --- a/tests/python/tir-transform/test_tir_transform_helpers.py +++ b/tests/python/tir-transform/test_tir_transform_helpers.py @@ -33,7 +33,7 @@ def func1(A: T.Buffer((16,), "float32")): mod = MockModule assert mod - assert mod["func1"].attrs is None + assert not mod["func1"].attrs after = tvm.tir.transform.AnnotateEntryFunc()(mod) assert ( after["func1"].attrs @@ -64,8 +64,8 @@ def func2(A: T.Buffer((32,), "float32")): def test_annotate_entry_func_multiple_primfunc(): mod = MockModule assert mod - assert mod["func1"].attrs is None - assert mod["func2"].attrs is None + assert not mod["func1"].attrs + assert not mod["func2"].attrs # This should fail after = tvm.tir.transform.AnnotateEntryFunc()(mod) @@ -75,13 +75,13 @@ def test_bind_target(): assert mod target = tvm.target.Target("cuda") - assert mod["func1"].attrs is None - assert mod["func2"].attrs is None + assert not mod["func1"].attrs + assert not mod["func2"].attrs after = tvm.tir.transform.BindTarget(target)(mod) - assert after["func1"].attrs and "target" in after["func1"].attrs + assert "target" in after["func1"].attrs assert after["func1"].attrs["target"] == target - assert after["func2"].attrs and "target" in after["func2"].attrs + assert "target" in after["func2"].attrs assert after["func2"].attrs["target"] == target @@ -218,7 +218,7 @@ def test_filter_primfunc(): # Test condition that does not filter out anything def checker_filter_out_none(func: tvm.tir.PrimFunc): - return (func.attrs is not None) and ("temp" in func.attrs) + return "temp" in func.attrs after = tvm.tir.transform.Filter(checker_filter_out_none)(mod) assert len(after.functions) == 2 @@ -228,7 +228,7 @@ def checker_filter_out_none(func: tvm.tir.PrimFunc): # Test condition that selectively filters out primfuncs def checker_filter_out_one(func: tvm.tir.PrimFunc): - return (func.attrs is not None) and ("temp" in func.attrs) and func.attrs["temp"] == "test1" + return ("temp" in func.attrs) and func.attrs["temp"] == "test1" after = tvm.tir.transform.Filter(checker_filter_out_one)(mod) assert len(after.functions) == 1 @@ -237,7 +237,7 @@ def checker_filter_out_one(func: tvm.tir.PrimFunc): # Test condition that filters out everything def checker_filter_out_both(func: tvm.tir.PrimFunc): - return (func.attrs is not None) and ("invalid_attr" in func.attrs) + return "invalid_attr" in func.attrs after = tvm.tir.transform.Filter(checker_filter_out_both)(mod) assert len(after.functions) == 0 From 8fba5a6a1f35056757e3bbf50b0e259575dcc53f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 20 Mar 2024 14:36:37 -0500 Subject: [PATCH 3/7] Fix up unit tests --- python/tvm/relay/function.py | 3 +++ python/tvm/relay/transform/recast.py | 2 +- src/relay/ir/function.cc | 2 ++ src/script/ir_builder/relax/ir.cc | 21 ++++++++++++++------- src/script/ir_builder/tir/frame.cc | 2 +- src/script/ir_builder/tir/ir.cc | 5 +---- 6 files changed, 22 insertions(+), 13 deletions(-) diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py index 54adb45d8cbe..f1eada9159e1 100644 --- a/python/tvm/relay/function.py +++ b/python/tvm/relay/function.py @@ -54,6 +54,9 @@ def __init__(self, params, body, ret_type=None, type_params=None, attrs=None, sp if type_params is None: type_params = convert([]) + if attrs is None: + attrs = tvm.ir.make_node("DictAttrs") + self.__init_handle_by_constructor__( _ffi_api.Function, params, body, ret_type, type_params, attrs, span ) diff --git a/python/tvm/relay/transform/recast.py b/python/tvm/relay/transform/recast.py index 36824fb93103..39f07b2eb926 100644 --- a/python/tvm/relay/transform/recast.py +++ b/python/tvm/relay/transform/recast.py @@ -75,7 +75,7 @@ def visit_call(self, call): # If out_dtype is in the attributes, we need to update it. orig_dtype = None - if "out_dtype" in call.attrs.keys(): + if call.attrs is not None and "out_dtype" in call.attrs.keys(): new_attr_dict = {} for attr in call.attrs.keys(): attr_value = call.attrs[attr] diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 8d2477f109a4..b5414b27cf22 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -32,6 +32,8 @@ namespace relay { Function::Function(tvm::Array params, Expr body, Type ret_type, tvm::Array type_params, DictAttrs attrs, Span span) { + CHECK(attrs.defined()); + ObjectPtr n = make_object(); ICHECK(params.defined()); ICHECK(type_params.defined()); diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index 285a3a348e3b..60f78c0f58bb 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -84,14 +84,21 @@ void FuncName(const String& name) { void FuncAttrs(Map attrs) { FunctionFrame frame = FindFunctionFrame("R.func_attr"); - if (!frame->attrs.empty()) { - LOG(FATAL) << "ValueError: Duplicate function attrs, previous one is:\n" << frame->attrs; - } - if (attrs.count(tvm::attr::kGlobalSymbol) && frame->is_private.value_or(Bool(false))->value) { - LOG(FATAL) << "ValueError: Specifying a global symbol attribute even though the function is " - "annotated as private"; + for (const auto& [key, value] : attrs) { + if (key == tvm::attr::kGlobalSymbol && frame->is_private.value_or(Bool(false))->value) { + LOG(FATAL) << "ValueError: " + << "A private function may not have the kGlobalSymbol (\"" + << tvm::attr::kGlobalSymbol << "\") attribute. " + << "However, a private function specified the global symbol as " << value; + } + if (auto prev = frame->attrs.Get(key)) { + LOG(FATAL) << "ValueError: " + << "Duplicate R.func_attr annotation for key = \"" << key << "\". " + << "Previous value was " << prev.value() << ", with later definition as " << value; + } else { + frame->attrs.Set(key, value); + } } - frame->attrs = attrs; } void FuncRetStructInfo(const tvm::relax::StructInfo& ret_sinfo) { diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index 42adebf2d848..f0f7a60911c1 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -32,7 +32,7 @@ void PrimFuncFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); // if the prim func is not private and there isn't already a global symbol, // add a global symbol - if (!is_private && name.defined()) { + if (!is_private && name.defined() && !attrs.count(tvm::attr::kGlobalSymbol)) { attrs.Set(tvm::attr::kGlobalSymbol, name.value()); } diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index aa8b7131ef2b..1ae1051d254d 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -94,7 +94,6 @@ void FuncName(String name) { void FuncAttrs(Map new_attrs) { using namespace tvm::tir; PrimFuncFrame frame = FindPrimFuncFrame("T.func_attr"); - auto attrs = frame->attrs; for (const auto& [key, value] : new_attrs) { if (key == tvm::attr::kGlobalSymbol && frame->is_private) { LOG(FATAL) << "ValueError: " @@ -103,7 +102,7 @@ void FuncAttrs(Map new_attrs) { << "However, a private function specified the global symbol as " << value; } - if (auto prev = attrs.Get(key)) { + if (auto prev = frame->attrs.Get(key)) { LOG(FATAL) << "ValueError: " << "Duplicate prim func annotation for key = \"" << key << "\". " << "Previous value was " << prev.value() << ", with later definition as " << value; @@ -111,8 +110,6 @@ void FuncAttrs(Map new_attrs) { frame->attrs.Set(key, value); } } - - frame->attrs = attrs; } tvm::Type FuncRet(tvm::Type ret_type) { From edd544800c8c82eabeadd278615e18bdebc9a3f3 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 21 Mar 2024 19:12:50 -0500 Subject: [PATCH 4/7] More unit test fixes --- python/tvm/testing/aot.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/testing/aot.py b/python/tvm/testing/aot.py index f6d99090510e..c4d70a26b367 100644 --- a/python/tvm/testing/aot.py +++ b/python/tvm/testing/aot.py @@ -1001,12 +1001,12 @@ def generate_ref_data(mod, input_data, params=None, target="llvm"): main = mod else: main = mod["main"] - if main.attrs["output_tensor_names"] is None: + if "output_tensor_names" in main.attrs: + output_tensor_names = main.attrs["output_tensor_names"] + else: output_tensor_names = ( ["output"] if output_count == 1 else [f"output{i}" for i in range(output_count)] ) - else: - output_tensor_names = main.attrs["output_tensor_names"] return dict(zip(output_tensor_names, out)) From 72522bab36411575d912375e91ea9e898b8c063e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 22 Mar 2024 07:24:23 -0500 Subject: [PATCH 5/7] Undo erroneous find/replace --- tests/python/contrib/test_dnnl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/contrib/test_dnnl.py b/tests/python/contrib/test_dnnl.py index 27ff7b8a38dd..c45149fc5f1e 100644 --- a/tests/python/contrib/test_dnnl.py +++ b/tests/python/contrib/test_dnnl.py @@ -1374,7 +1374,7 @@ def _visit(node): op = node.op if isinstance(op, relay.GlobalVar): func = mod[op] - if "Compiler" in func.attrs["Compiler"] == desired_compiler: + if "Compiler" in func.attrs and func.attrs["Compiler"] == desired_compiler: matched_ops.append(op) return else: From beee24607a5a85d5a5b610a3a5f92bdc1078fb11 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 22 Mar 2024 14:25:57 -0500 Subject: [PATCH 6/7] A few more unit tests --- python/tvm/relay/testing/py_converter.py | 6 +++++- python/tvm/tir/function.py | 3 +++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py index 799a106ea613..8e2cbe10822c 100644 --- a/python/tvm/relay/testing/py_converter.py +++ b/python/tvm/relay/testing/py_converter.py @@ -553,7 +553,11 @@ def visit_call(self, call: Expr): # lowered operator: generate a call to a function that gets the PackedFunc # from TVM's registry - if isinstance(func, Function) and func.attrs.Primitive.value == 1: + if ( + isinstance(func, Function) + and hasattr(func.attrs, "Primitive") + and int(func.attrs.Primitive) == 1 + ): op_call_def, op_call = self.create_op_call(func, call.args, fields) return (op_call, field_defs + [op_call_def]) diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index bd44e3f7c3de..eb3c50b409c8 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -80,6 +80,9 @@ def __init__( else: raise TypeError("params can only contain Var or Buffer") + if attrs is None: + attrs = tvm.ir.make_node("DictAttrs") + self.__init_handle_by_constructor__( _ffi_api.PrimFunc, param_list, From 05b23bb11147f7e74d56efbb6859d7b3a98ada08 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sat, 23 Mar 2024 07:23:29 -0500 Subject: [PATCH 7/7] Provide `DictAttrs.get` --- python/tvm/ir/attrs.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/tvm/ir/attrs.py b/python/tvm/ir/attrs.py index 05fe684635dd..6f0a6dd7d155 100644 --- a/python/tvm/ir/attrs.py +++ b/python/tvm/ir/attrs.py @@ -114,6 +114,10 @@ def keys(self): def __getitem__(self, k): return self._dict().__getitem__(k) + def get(self, key, default=None): + """Get an element with a default value.""" + return self._dict().get(key, default) + def __contains__(self, k): return self._dict().__contains__(k)