diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h
index 18d0f025c776..81611b1a535a 100644
--- a/include/tvm/ir/attrs.h
+++ b/include/tvm/ir/attrs.h
@@ -230,7 +230,7 @@ class DictAttrs : public Attrs {
    * \brief Consruct a Attrs backed by DictAttrsNode.
    * \param dict The attributes.
    */
-  TVM_DLL explicit DictAttrs(Map<String, ObjectRef> dict);
+  TVM_DLL explicit DictAttrs(Map<String, ObjectRef> dict = {});
 
   // Utils for accessing attributes
   // This needs to be on DictAttrs, not DictAttrsNode because we return the default
@@ -298,7 +298,7 @@ class DictAttrs : public Attrs {
     return GetAttr<Integer>(attr_key, 0).value_or(0).IntValue() != 0;
   }
 
-  TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode);
+  TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(DictAttrs, Attrs, DictAttrsNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
 };
 
@@ -415,9 +415,6 @@ inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) {
   if (input->attrs.defined()) {
     TNode* node = input.CopyOnWrite();
     node->attrs.CopyOnWrite()->dict.erase(attr_key);
-    if (node->attrs->dict.size() == 0) {
-      node->attrs = NullValue<DictAttrs>();
-    }
   }
   return input;
 }
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index ad6efa529cc2..2a5412a5671f 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -376,7 +376,8 @@ class IRModule : public ObjectRef {
   TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
                             Map<GlobalTypeVar, TypeData> type_definitions = {},
                             std::unordered_set<String> import_set = {}, SourceMap map = {},
-                            DictAttrs attrs = {}, Map<String, Array<GlobalInfo>> global_infos = {});
+                            DictAttrs attrs = DictAttrs(),
+                            Map<String, Array<GlobalInfo>> global_infos = {});
 
   /*! \brief default constructor */
   IRModule() : IRModule(Map<GlobalVar, BaseFunc>({})) {}
diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index fdbd7bd8eb2c..4634d1e228d3 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -983,15 +983,14 @@ class FunctionNode : public BaseFuncNode {
 class Function : public BaseFunc {
  public:
   TVM_DLL explicit Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
-                            bool is_pure = true, DictAttrs attrs = NullValue<DictAttrs>(),
-                            Span span = Span());
+                            bool is_pure = true, DictAttrs attrs = DictAttrs(), Span span = Span());
 
   /*!
    * \brief Mimics the constructor but without body Expr.
    * \note ret_struct_info is required, since it can not deduced by the body.
    */
   TVM_DLL static Function CreateEmpty(Array<Var> params, StructInfo ret_struct_info,
-                                      bool is_pure = true, DictAttrs attrs = NullValue<DictAttrs>(),
+                                      bool is_pure = true, DictAttrs attrs = DictAttrs(),
                                       Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h
index 874d4f233416..798f6d4d2566 100644
--- a/include/tvm/relay/function.h
+++ b/include/tvm/relay/function.h
@@ -114,7 +114,7 @@ class Function : public BaseFunc {
    * \param span The span of the function.
    */
   TVM_DLL Function(tvm::Array<Var> params, Expr body, Type ret_type, tvm::Array<TypeVar> ty_params,
-                   tvm::DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
+                   tvm::DictAttrs attrs = DictAttrs(), Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
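Note on the header changes above: every function constructor now defaults `attrs` to an empty `DictAttrs` instead of a null one, so downstream code can treat `func.attrs` as always defined. A minimal sketch of the resulting invariant, assuming this patch is applied (the attribute key is illustrative):

```python
import tvm
from tvm import tir

func = tir.PrimFunc([], tir.Evaluate(0))  # no attrs passed
assert func.attrs is not None             # attrs is always a DictAttrs now
assert not func.attrs                     # ...but starts out empty, hence falsy
assert "tir.is_scheduled" not in func.attrs  # membership test needs no None-guard
```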
diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h
index 92f477b058fd..172316daae59 100644
--- a/include/tvm/runtime/object.h
+++ b/include/tvm/runtime/object.h
@@ -741,14 +741,24 @@ struct ObjectPtrEqual {
  * \param ParentType The parent type of the objectref
  * \param ObjectName The type name of the object.
  */
-#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName)                           \
-  TypeName() = default;                                                                           \
+#define TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TypeName, ParentType,           \
+                                                                  ObjectName)                     \
   explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {}       \
   TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName);                                              \
   const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); }    \
   const ObjectName* get() const { return operator->(); }                                          \
   using ContainerType = ObjectName;
 
+/*
+ * \brief Define object reference methods.
+ * \param TypeName The object type name
+ * \param ParentType The parent type of the objectref
+ * \param ObjectName The type name of the object.
+ */
+#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
+  TypeName() = default;                                                 \
+  TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TypeName, ParentType, ObjectName)
+
 /*
  * \brief Define object reference methods that is not nullable.
 *
diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h
index 0cc385d876a8..598750f0ac48 100644
--- a/include/tvm/script/ir_builder/tir/frame.h
+++ b/include/tvm/script/ir_builder/tir/frame.h
@@ -78,7 +78,7 @@ class PrimFuncFrameNode : public TIRFrameNode {
   /*! \brief Maps some parameters to specific Buffer data structures. */
   Map<tvm::tir::Var, tvm::tir::Buffer> buffer_map;
   /*! \brief Additional attributes storing the meta-data */
-  Optional<Map<String, ObjectRef>> attrs;
+  Map<String, ObjectRef> attrs;
   /*! \brief The variable map bound to thread env. */
   Map<tvm::tir::Var, tvm::tir::IterVar> env_threads;
   /*! \brief The buffer allocated in root block. */
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index 1917a3c22c6e..274ebd0a6558 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -164,7 +164,7 @@ class PrimFunc : public BaseFunc {
    */
   TVM_DLL PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
                    Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(),
-                   DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
+                   DictAttrs attrs = DictAttrs(), Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 80169f51640e..59803f20feb5 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -977,7 +977,7 @@ def handle_norm(self, f, op_type):
         return f.with_attrs(attrs)
 
     def visit_function_(self, f):
-        if f.attrs is None or "Composite" not in f.attrs:
+        if "Composite" not in f.attrs:
             body = super().visit_expr(f.body)
             return relax.Function(f.params, body, f.ret_struct_info, f.is_pure, f.attrs, f.span)
else: + name = "" + node_id = node_to_id[node] # Body -> FunctionNode @@ -244,11 +244,10 @@ def _call( elif isinstance(node.op, relay.Function): func_attrs = node.op.attrs op_name = "Anonymous Func" - if func_attrs: - node_detail = [f"{k}: {func_attrs.get_str(k)}" for k in func_attrs.keys()] - # "Composite" might from relay.transform.MergeComposite - if "Composite" in func_attrs.keys(): - op_name = func_attrs["Composite"] + node_detail = [f"{k}: {func_attrs.get_str(k)}" for k in func_attrs.keys()] + # "Composite" might from relay.transform.MergeComposite + if "Composite" in func_attrs.keys(): + op_name = func_attrs["Composite"] elif isinstance(node.op, relay.GlobalVar): op_name = "GlobalVar" node_detail = [f"GlobalVar.name_hint: {node.op.name_hint}"] diff --git a/python/tvm/dlight/base/transform.py b/python/tvm/dlight/base/transform.py index 89ecaa6350fb..d697e9440b31 100644 --- a/python/tvm/dlight/base/transform.py +++ b/python/tvm/dlight/base/transform.py @@ -31,8 +31,6 @@ def _is_scheduled(func: tir.PrimFunc) -> bool: if not isinstance(func, tir.PrimFunc): return False - if not func.attrs: - return False if "tir.is_scheduled" not in func.attrs: return False return func.attrs["tir.is_scheduled"] == 1 diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py index 9318b9149245..0f224b89f9e4 100644 --- a/python/tvm/dlight/gpu/matmul.py +++ b/python/tvm/dlight/gpu/matmul.py @@ -335,7 +335,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring root_block = analysis.get_root_block(sch) blocks = sch.get_child_blocks(root_block) - if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + if "dlight.do_not_tensorize" in func.attrs.keys(): return None reduction_blocks = get_reduction_blocks(sch, blocks) @@ -556,7 +556,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring root_block = analysis.get_root_block(sch) blocks = sch.get_child_blocks(root_block) - if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + if "dlight.do_not_tensorize" in func.attrs.keys(): return None reduction_blocks = get_reduction_blocks(sch, blocks) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index e23765e92d8c..c332062b37b9 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -249,7 +249,7 @@ def build( if target is None and isinstance(input_mod, tvm.IRModule): target_mod = {} for gvar, func in input_mod.functions.items(): - tgt = func.attrs["target"] if func.attrs and "target" in func.attrs else "llvm" + tgt = func.attrs["target"] if "target" in func.attrs else "llvm" if tgt not in target_mod: target_mod[tgt] = {} target_mod[tgt][gvar] = func diff --git a/python/tvm/ir/attrs.py b/python/tvm/ir/attrs.py index 05fe684635dd..6f0a6dd7d155 100644 --- a/python/tvm/ir/attrs.py +++ b/python/tvm/ir/attrs.py @@ -114,6 +114,10 @@ def keys(self): def __getitem__(self, k): return self._dict().__getitem__(k) + def get(self, key, default=None): + """Get an element with a default value.""" + return self._dict().get(key, default) + def __contains__(self, k): return self._dict().__contains__(k) diff --git a/python/tvm/meta_schedule/relax_integration.py b/python/tvm/meta_schedule/relax_integration.py index 57daeea2d97b..c3c24aa631d6 100644 --- a/python/tvm/meta_schedule/relax_integration.py +++ b/python/tvm/meta_schedule/relax_integration.py @@ -138,7 +138,7 @@ def extracted_tasks_to_tune_contexts( get_loggers_from_work_dir(work_dir, [t.task_name for t in 
diff --git a/python/tvm/meta_schedule/relax_integration.py b/python/tvm/meta_schedule/relax_integration.py
index 57daeea2d97b..c3c24aa631d6 100644
--- a/python/tvm/meta_schedule/relax_integration.py
+++ b/python/tvm/meta_schedule/relax_integration.py
@@ -138,7 +138,7 @@ def extracted_tasks_to_tune_contexts(
         get_loggers_from_work_dir(work_dir, [t.task_name for t in extracted_tasks]),
         fork_seed(seed, n=len(extracted_tasks)),
     ):
-        if task.mod.attrs is not None and task.mod.attrs.get("tir.is_scheduled", False):
+        if task.mod.attrs.get("tir.is_scheduled", False):
            warnings.warn("The task {task.task_name} is already scheduled, skipping it.")
             continue
         tasks.append(
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index a611bee2bbcd..0d9f4ff8e923 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -526,11 +526,11 @@ def __init__(self, mod):
         super().__init__(mod)
 
     def visit_function_(self, f):
-        if f.attrs is None or "Composite" not in f.attrs:
+        if "Composite" not in f.attrs:
             body = super().visit_expr(f.body)
             new_f = Function(f.params, body, f.ret_struct_info, f.is_pure, f.attrs, f.span)
 
-            if f.attrs and "global_symbol" in f.attrs and "cutlass" in f.attrs["global_symbol"]:
+            if "global_symbol" in f.attrs and "cutlass" in f.attrs["global_symbol"]:
                 composite_func = body.blocks[0].bindings[0].value
                 if "WorkspaceSize" in composite_func.attrs:
                     return new_f.with_attr("WorkspaceSize", composite_func.attrs["WorkspaceSize"])
diff --git a/python/tvm/relax/frontend/common.py b/python/tvm/relax/frontend/common.py
index cc36bbbc72ba..bbd0c55aac2e 100644
--- a/python/tvm/relax/frontend/common.py
+++ b/python/tvm/relax/frontend/common.py
@@ -42,7 +42,7 @@ def detach_params(mod: tvm.IRModule) -> Tuple[tvm.IRModule, Dict[str, List[tvm.n
     detached_mod = tvm.IRModule()
     params_dict = dict()
     for gv, func in mod.functions_items():
-        if func.attrs is not None and "params" in func.attrs:
+        if "params" in func.attrs:
             params = list(func.attrs["params"])
             if not all([isinstance(param, tvm.nd.NDArray) for param in params]):
                 raise ValueError(
diff --git a/python/tvm/relax/training/setup_trainer.py b/python/tvm/relax/training/setup_trainer.py
index 2e2057086904..71bf8509a63e 100644
--- a/python/tvm/relax/training/setup_trainer.py
+++ b/python/tvm/relax/training/setup_trainer.py
@@ -138,19 +138,15 @@ def _check_well_formed(self, mod: IRModule):
             ) from exc
 
         # Check function attrs
-        if (
-            mod.attrs is None
-            or not self.PARAM_NUM_ATTR_KEY in mod.attrs
-            or not isinstance(mod.attrs[self.PARAM_NUM_ATTR_KEY], IntImm)
+        if not self.PARAM_NUM_ATTR_KEY in mod.attrs or not isinstance(
+            mod.attrs[self.PARAM_NUM_ATTR_KEY], IntImm
         ):
             raise ValueError(
                 f"SetupTrainer: The backbone module should has an integer attribute named "
                 f"{self.PARAM_NUM_ATTR_KEY}"
             )
-        if (
-            mod.attrs is None
-            or not self.STATE_NUM_ATTR_KEY in mod.attrs
-            or not isinstance(mod.attrs[self.STATE_NUM_ATTR_KEY], IntImm)
+        if not self.STATE_NUM_ATTR_KEY in mod.attrs or not isinstance(
+            mod.attrs[self.STATE_NUM_ATTR_KEY], IntImm
         ):
             raise ValueError(
                 f"SetupTrainer: The backbone module should has an integer attribute named "
diff --git a/python/tvm/relax/transform/lazy_transform_params.py b/python/tvm/relax/transform/lazy_transform_params.py
index e8e8229965c5..1b025f7d3a6a 100644
--- a/python/tvm/relax/transform/lazy_transform_params.py
+++ b/python/tvm/relax/transform/lazy_transform_params.py
@@ -138,7 +138,7 @@ def __init__(
         self.memory_free_insertion = None
 
     def transform(self, func: relax.Function) -> relax.Function:
-        if func.attrs is not None and "num_input" in func.attrs:
+        if "num_input" in func.attrs:
             num_input = func.attrs["num_input"].value
         else:
             num_input = 0
@@ -235,7 +235,7 @@ def __init__(self, func_creator, mod: Optional[IRModule] = None) -> None:
         super().__init__(mod)
 
     def visit_function_(self, func: relax.Function) -> relax.Expr:
-        if func.attrs is not None and "num_input" in func.attrs:
+        if "num_input" in func.attrs:
             num_input = func.attrs["num_input"].value
         else:
             num_input = 0
diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py
index 289754d5c370..a402604b4c11 100644
--- a/python/tvm/relay/backend/contrib/ethosu/util.py
+++ b/python/tvm/relay/backend/contrib/ethosu/util.py
@@ -156,7 +156,7 @@ class QPadArgs(Enum):
 
 def is_npu_func(func: relay.Function) -> bool:
     """Check if the given function is an NPU function."""
-    return func.attrs and "Compiler" in func.attrs and func.attrs["Compiler"] == "ethos-u"
+    return "Compiler" in func.attrs and func.attrs["Compiler"] == "ethos-u"
 
 
 def is_composite_func(func: relay.Function, name: str) -> bool:
diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py
index 54adb45d8cbe..f1eada9159e1 100644
--- a/python/tvm/relay/function.py
+++ b/python/tvm/relay/function.py
@@ -54,6 +54,9 @@ def __init__(self, params, body, ret_type=None, type_params=None, attrs=None, span=None):
         if type_params is None:
             type_params = convert([])
 
+        if attrs is None:
+            attrs = tvm.ir.make_node("DictAttrs")
+
         self.__init_handle_by_constructor__(
             _ffi_api.Function, params, body, ret_type, type_params, attrs, span
         )
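`relay.Function` above (and `tir.PrimFunc` below) now normalize `attrs=None` on the Python side, because the C++ constructors stop accepting undefined attrs later in this patch. A sketch of what the substituted default looks like:

```python
import tvm

# With no kwargs, make_node("DictAttrs") builds an empty DictAttrs node,
# which is what the Python constructors substitute for attrs=None.
attrs = tvm.ir.make_node("DictAttrs")
assert isinstance(attrs, tvm.ir.DictAttrs)
assert "num_input" not in attrs  # empty, but supports membership tests
```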
diff --git a/python/tvm/relay/quantize/_partition_conversions.py b/python/tvm/relay/quantize/_partition_conversions.py
index 8ba5c9ae2f20..8fec69cdf53e 100644
--- a/python/tvm/relay/quantize/_partition_conversions.py
+++ b/python/tvm/relay/quantize/_partition_conversions.py
@@ -215,7 +215,7 @@ def partition_prefix(mod, quantized_dtypes):
     prefix_cutter = PrefixCutter(func.params, quantized_dtypes)
     mid_body = prefix_cutter.visit(func.body)
     assert not func.type_params, "unimplemented"
-    assert func.attrs is None, "unimplemented"
+    assert not func.attrs, "unimplemented"
     mid_func = relay.Function(relay.analysis.free_vars(mid_body), mid_body)
     mid_mod = tvm.IRModule.from_expr(mid_func)
     mid_mod = relay.transform.InferType()(mid_mod)
@@ -288,7 +288,7 @@ def partition_suffix(mod, quantized_dtypes):
     suffix_cutter = SuffixCutter(quantized_dtypes)
     post_body = suffix_cutter.visit(func.body)
     assert not func.type_params, "unimplemented"
-    assert func.attrs is None, "unimplemented"
+    assert not func.attrs, "unimplemented"
     post_func = relay.Function(relay.analysis.free_vars(post_body), post_body, func.ret_type)
     post_mod = tvm.IRModule.from_expr(post_func)
     post_mod = relay.transform.InferType()(post_mod)
diff --git a/python/tvm/relay/testing/py_converter.py b/python/tvm/relay/testing/py_converter.py
index 9cbfcead4783..8e2cbe10822c 100644
--- a/python/tvm/relay/testing/py_converter.py
+++ b/python/tvm/relay/testing/py_converter.py
@@ -553,7 +553,11 @@ def visit_call(self, call: Expr):
 
         # lowered operator: generate a call to a function that gets the PackedFunc
         # from TVM's registry
-        if isinstance(func, Function) and func.attrs and func.attrs.Primitive.value == 1:
+        if (
+            isinstance(func, Function)
+            and hasattr(func.attrs, "Primitive")
+            and int(func.attrs.Primitive) == 1
+        ):
             op_call_def, op_call = self.create_op_call(func, call.args, fields)
             return (op_call, field_defs + [op_call_def])
diff --git a/python/tvm/testing/aot.py b/python/tvm/testing/aot.py
index 8d74f545a3c2..c4d70a26b367 100644
--- a/python/tvm/testing/aot.py
+++ b/python/tvm/testing/aot.py
@@ -1001,12 +1001,12 @@ def generate_ref_data(mod, input_data, params=None, target="llvm"):
         main = mod
     else:
         main = mod["main"]
-    if main.attrs is None or main.attrs["output_tensor_names"] is None:
+    if "output_tensor_names" in main.attrs:
+        output_tensor_names = main.attrs["output_tensor_names"]
+    else:
         output_tensor_names = (
             ["output"] if output_count == 1 else [f"output{i}" for i in range(output_count)]
         )
-    else:
-        output_tensor_names = main.attrs["output_tensor_names"]
 
     return dict(zip(output_tensor_names, out))
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index bd44e3f7c3de..eb3c50b409c8 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -80,6 +80,9 @@ def __init__(
         else:
             raise TypeError("params can only contain Var or Buffer")
 
+        if attrs is None:
+            attrs = tvm.ir.make_node("DictAttrs")
+
         self.__init_handle_by_constructor__(
             _ffi_api.PrimFunc,
             param_list,
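The py_converter change above swaps a truthiness test for `hasattr`: `DictAttrs` exposes its keys as Python attributes and raises `AttributeError` for missing ones, so `hasattr` is a reliable guard now that `attrs` is never `None`. A hedged sketch (passing initial key/value pairs to `make_node` is assumed here from DictAttrs' packed-args initializer; the key is illustrative):

```python
import tvm

attrs = tvm.ir.make_node("DictAttrs", Primitive=1)
assert hasattr(attrs, "Primitive")
assert int(attrs.Primitive) == 1  # the stored IntImm converts via int()

empty = tvm.ir.make_node("DictAttrs")
assert not hasattr(empty, "Primitive")  # missing key -> AttributeError -> False
```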
main.attrs["output_tensor_names"] is None: + if "output_tensor_names" in main.attrs: + output_tensor_names = main.attrs["output_tensor_names"] + else: output_tensor_names = ( ["output"] if output_count == 1 else [f"output{i}" for i in range(output_count)] ) - else: - output_tensor_names = main.attrs["output_tensor_names"] return dict(zip(output_tensor_names, out)) diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index bd44e3f7c3de..eb3c50b409c8 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -80,6 +80,9 @@ def __init__( else: raise TypeError("params can only contain Var or Buffer") + if attrs is None: + attrs = tvm.ir.make_node("DictAttrs") + self.__init_handle_by_constructor__( _ffi_api.PrimFunc, param_list, diff --git a/src/relay/analysis/type_solver.cc b/src/relay/analysis/type_solver.cc index 5bd5698d8321..c4fab210acb8 100644 --- a/src/relay/analysis/type_solver.cc +++ b/src/relay/analysis/type_solver.cc @@ -659,7 +659,7 @@ TVM_REGISTER_GLOBAL("relay.analysis._test_type_solver") auto module = IRModule({}, {}); DiagnosticContext diag_ctx = DiagnosticContext::Default(module); auto dummy_fn_name = GlobalVar("test"); - module->Add(dummy_fn_name, Function({}, Tuple(tvm::Array({})), Type(), {}, {})); + module->Add(dummy_fn_name, Function({}, Tuple(tvm::Array({})), Type(), {})); auto solver = std::make_shared(dummy_fn_name, diag_ctx); auto mod = [module, solver, diag_ctx](std::string name) -> PackedFunc { diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index ba94e4b19ec7..48449eb02149 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -194,7 +194,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { CHECK_EQ(before_arity, after_arity); lifted_func = Function(typed_captured_vars, rebound_body, /*ret_type=*/func->func_type_annotation(), - free_type_vars, /*attrs=*/{}, func->span); + free_type_vars, DictAttrs(), func->span); lifted_func->virtual_device_ = result_virtual_device; lifted_func = MarkClosure(lifted_func); } diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index ee585446cb26..8e756a8aa2d3 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -438,7 +438,7 @@ Expr InferTypeWithModule(const Expr& expr, const IRModule& m) { if (expr.as()) { func = Downcast(expr); } else { - func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {}); + func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod)); } mod->Add(gvar, func); mod = transform::InferType()(mod); diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index fd8c646ecf1c..b5414b27cf22 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -32,6 +32,8 @@ namespace relay { Function::Function(tvm::Array params, Expr body, Type ret_type, tvm::Array type_params, DictAttrs attrs, Span span) { + CHECK(attrs.defined()); + ObjectPtr n = make_object(); ICHECK(params.defined()); ICHECK(type_params.defined()); @@ -251,7 +253,7 @@ TVM_REGISTER_GLOBAL("relay.ir.IRModuleUpdateWithRenamer") TVM_REGISTER_GLOBAL("relay.ir.FunctionFromExprInContext") .set_body_typed([](RelayExpr expr, IRModule mod) -> Function { - return Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {}); + return Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod)); }); TVM_REGISTER_GLOBAL("relay.ir.FuncWithAttr") diff --git 
a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc
index a989cf53f818..c192097a0b29 100644
--- a/src/relay/transforms/dynamic_to_static.cc
+++ b/src/relay/transforms/dynamic_to_static.cc
@@ -253,8 +253,7 @@ class DynamicToStaticMutator : public MixedModeMutator {
     if (auto func_node = expr.as<FunctionNode>()) {
       func = func_node.value();
     } else {
-      func =
-          relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod_), {});
+      func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod_));
     }
     mod_->Update(gv_, func);
diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc
index 7c90d101b567..05d49cf5047c 100644
--- a/src/relay/transforms/to_cps.cc
+++ b/src/relay/transforms/to_cps.cc
@@ -170,7 +170,7 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm, VarMap* vm,
 
 Expr reify(const MCont& k) {
   Var arg = Var("arg", Type());
-  return Function({arg}, k(arg), Type(), {}, {});
+  return Function({arg}, k(arg), Type(), {});
 }
 
 Expr reify(const MCont& k, const std::function<Expr(MCont)>& cont) {
@@ -328,7 +328,7 @@ Function UnCPS(const Function& f) {
   // TODO(@M.K.): make alphaequal work on free term
   // ICHECK(tvm::StructuralEqual()(cont_type, Arrow(new_ret_type, answer_type)));
   auto x = Var("x", new_ret_type);
-  auto cont = Function({x}, x, new_ret_type, {}, {});
+  auto cont = Function({x}, x, new_ret_type, {});
   tvm::Array<Expr> args;
   for (const auto& p : new_params) {
     args.push_back(p);
diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc
index 3d917cee887b..60a35ee010ec 100644
--- a/src/script/ir_builder/ir/frame.cc
+++ b/src/script/ir_builder/ir/frame.cc
@@ -38,7 +38,7 @@ void IRModuleFrameNode::ExitWithScope() {
   }
   IRBuilder builder = IRBuilder::Current();
   ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
-  auto dict_attrs = attrs.empty() ? NullValue<DictAttrs>() : DictAttrs(attrs);
+  auto dict_attrs = DictAttrs(attrs);
   builder->result = tvm::IRModule(func_map, {}, {}, {}, dict_attrs, global_infos);
 }
diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc
index 285a3a348e3b..60f78c0f58bb 100644
--- a/src/script/ir_builder/relax/ir.cc
+++ b/src/script/ir_builder/relax/ir.cc
@@ -84,14 +84,21 @@ void FuncName(const String& name) {
 
 void FuncAttrs(Map<String, ObjectRef> attrs) {
   FunctionFrame frame = FindFunctionFrame("R.func_attr");
-  if (!frame->attrs.empty()) {
-    LOG(FATAL) << "ValueError: Duplicate function attrs, previous one is:\n" << frame->attrs;
-  }
-  if (attrs.count(tvm::attr::kGlobalSymbol) && frame->is_private.value_or(Bool(false))->value) {
-    LOG(FATAL) << "ValueError: Specifying a global symbol attribute even though the function is "
-                  "annotated as private";
+  for (const auto& [key, value] : attrs) {
+    if (key == tvm::attr::kGlobalSymbol && frame->is_private.value_or(Bool(false))->value) {
+      LOG(FATAL) << "ValueError: "
+                 << "A private function may not have the kGlobalSymbol (\""
+                 << tvm::attr::kGlobalSymbol << "\") attribute. "
+                 << "However, a private function specified the global symbol as " << value;
+    }
+    if (auto prev = frame->attrs.Get(key)) {
+      LOG(FATAL) << "ValueError: "
+                 << "Duplicate R.func_attr annotation for key = \"" << key << "\". "
" + << "Previous value was " << prev.value() << ", with later definition as " << value; + } else { + frame->attrs.Set(key, value); + } } - frame->attrs = attrs; } void FuncRetStructInfo(const tvm::relax::StructInfo& ret_sinfo) { diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index c15a290bf03d..f0f7a60911c1 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -32,18 +32,8 @@ void PrimFuncFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); // if the prim func is not private and there isn't already a global symbol, // add a global symbol - if (!is_private && name.defined()) { - if (!attrs.defined()) { - attrs = {{tvm::attr::kGlobalSymbol, name.value()}}; - } else if (!attrs.value().count(tvm::attr::kGlobalSymbol)) { - // copy over attributes (can't mutate the dict inside the optional in-place) - Map new_attrs; - for (auto kv : attrs.value()) { - new_attrs.Set(kv.first, kv.second); - } - new_attrs.Set(tvm::attr::kGlobalSymbol, name.value()); - attrs = std::move(new_attrs); - } + if (!is_private && name.defined() && !attrs.count(tvm::attr::kGlobalSymbol)) { + attrs.Set(tvm::attr::kGlobalSymbol, name.value()); } tvm::tir::PrimFunc func( @@ -51,7 +41,7 @@ void PrimFuncFrameNode::ExitWithScope() { /*body=*/AsStmt(stmts), /*ret_type=*/ret_type.value_or(TupleType::Empty()), /*buffer_map=*/buffer_map, - /*attrs=*/attrs.defined() ? DictAttrs(attrs.value()) : NullValue()); + /*attrs=*/DictAttrs(attrs)); func = tvm::tir::ScriptComplete(func, root_alloc_buffers); IRBuilder builder = IRBuilder::Current(); if (builder->frames.empty()) { diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index cf73ffa0eedd..1ae1051d254d 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -61,7 +61,7 @@ PrimFuncFrame PrimFunc(bool is_private) { n->args.clear(); n->ret_type = NullOpt; n->buffer_map.clear(); - n->attrs = NullOpt; + n->attrs = {}; n->env_threads.clear(); n->root_alloc_buffers.clear(); return PrimFuncFrame(n); @@ -91,17 +91,25 @@ void FuncName(String name) { frame->name = name; } -void FuncAttrs(Map attrs) { +void FuncAttrs(Map new_attrs) { using namespace tvm::tir; PrimFuncFrame frame = FindPrimFuncFrame("T.func_attr"); - if (frame->attrs.defined()) { - LOG(FATAL) << "ValueError: Duplicate prim func annotations, previous one is " << frame->attrs; - } - if (attrs.count(tvm::attr::kGlobalSymbol) && frame->is_private) { - LOG(FATAL) << "ValueError: Specifying the global symbol even though the PrimFunc is annotated " - "as private"; + for (const auto& [key, value] : new_attrs) { + if (key == tvm::attr::kGlobalSymbol && frame->is_private) { + LOG(FATAL) << "ValueError: " + << "A private function may not have the kGlobalSymbol (\"" + << tvm::attr::kGlobalSymbol << "\") attribute. " + << "However, a private function specified the global symbol as " << value; + } + + if (auto prev = frame->attrs.Get(key)) { + LOG(FATAL) << "ValueError: " + << "Duplicate prim func annotation for key = \"" << key << "\". 
" + << "Previous value was " << prev.value() << ", with later definition as " << value; + } else { + frame->attrs.Set(key, value); + } } - frame->attrs = attrs; } tvm::Type FuncRet(tvm::Type ret_type) { diff --git a/tests/python/contrib/test_coreml_codegen.py b/tests/python/contrib/test_coreml_codegen.py index 2edfafaa0bd8..f0cdf14aa019 100644 --- a/tests/python/contrib/test_coreml_codegen.py +++ b/tests/python/contrib/test_coreml_codegen.py @@ -140,7 +140,7 @@ def _construct_model(func, m1, m2): fcompile = tvm._ffi.get_global_func("relay.ext.coremlcompiler") for var, func in mod.functions.items(): - if func.attrs and "Compiler" in func.attrs and func.attrs["Compiler"] == "coremlcompiler": + if "Compiler" in func.attrs and func.attrs["Compiler"] == "coremlcompiler": fcompile(func) diff --git a/tests/python/meta_schedule/test_meta_schedule_cpu_dot_product.py b/tests/python/meta_schedule/test_meta_schedule_cpu_dot_product.py index 592c772a04dd..cc2731ff5974 100644 --- a/tests/python/meta_schedule/test_meta_schedule_cpu_dot_product.py +++ b/tests/python/meta_schedule/test_meta_schedule_cpu_dot_product.py @@ -43,7 +43,7 @@ def _schedule_dense(m: Optional[int], do_tune: bool, intrin=VNNI_INTRIN): """ def schedule_fn(sch, dense_block: Optional[BlockRV] = None) -> bool: - if sch.mod.attrs is not None and "dense" not in sch.mod.attrs["task_name"]: + if "dense" not in sch.mod.attrs["task_name"]: return False if dense_block is None: assert has_block(sch, "compute") diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index 11437f7d682a..fced7a84a832 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -328,7 +328,7 @@ def main( mod = partition_for_cutlass(Conv2dReLU, annotate_codegen=False) for f_var in mod.functions: func = mod[f_var] - if func.attrs and "Composite" in func.attrs: + if "Composite" in func.attrs: # verify that the function is not fused as residual block assert func.attrs["Composite"] == "cutlass.conv2d_bias_relu" @@ -554,7 +554,7 @@ def main( mod = partition_for_cutlass(TransposedMatmul, annotate_codegen=False) for f_var in mod.functions: func = mod[f_var] - if func.attrs and "Composite" in func.attrs: + if "Composite" in func.attrs: # verify that the function is not fused as transposed matmul assert func.attrs["Composite"] == "cutlass.matmul" @@ -575,7 +575,7 @@ def main(x: R.Tensor((128, 128), "float16"), w: R.Tensor((128, 128), "float16")) mod = partition_for_cutlass(Module, annotate_codegen=False) for f_var in mod.functions: func = mod[f_var] - if func.attrs and "Composite" in func.attrs: + if "Composite" in func.attrs: assert func.attrs["Composite"] == "cutlass.matmul" diff --git a/tests/python/tir-base/test_tir_nodes.py b/tests/python/tir-base/test_tir_nodes.py index f3498f8ec753..60f8278ec277 100644 --- a/tests/python/tir-base/test_tir_nodes.py +++ b/tests/python/tir-base/test_tir_nodes.py @@ -351,7 +351,7 @@ def test_prim_func(): assert len(func.buffer_map) == 1 f2 = func.with_attr({"calling_conv": 1, "tir.noalias": True}) assert f2.attrs["calling_conv"].value == 1 - assert func.attrs is None + assert not func.attrs def test_vars(): diff --git a/tests/python/tir-transform/test_tir_transform_helpers.py b/tests/python/tir-transform/test_tir_transform_helpers.py index d4cd01ade248..d2ea82a1402c 100644 --- a/tests/python/tir-transform/test_tir_transform_helpers.py +++ b/tests/python/tir-transform/test_tir_transform_helpers.py @@ -33,7 +33,7 @@ def func1(A: T.Buffer((16,), 
"float32")): mod = MockModule assert mod - assert mod["func1"].attrs is None + assert not mod["func1"].attrs after = tvm.tir.transform.AnnotateEntryFunc()(mod) assert ( after["func1"].attrs @@ -64,8 +64,8 @@ def func2(A: T.Buffer((32,), "float32")): def test_annotate_entry_func_multiple_primfunc(): mod = MockModule assert mod - assert mod["func1"].attrs is None - assert mod["func2"].attrs is None + assert not mod["func1"].attrs + assert not mod["func2"].attrs # This should fail after = tvm.tir.transform.AnnotateEntryFunc()(mod) @@ -75,13 +75,13 @@ def test_bind_target(): assert mod target = tvm.target.Target("cuda") - assert mod["func1"].attrs is None - assert mod["func2"].attrs is None + assert not mod["func1"].attrs + assert not mod["func2"].attrs after = tvm.tir.transform.BindTarget(target)(mod) - assert after["func1"].attrs and "target" in after["func1"].attrs + assert "target" in after["func1"].attrs assert after["func1"].attrs["target"] == target - assert after["func2"].attrs and "target" in after["func2"].attrs + assert "target" in after["func2"].attrs assert after["func2"].attrs["target"] == target @@ -218,7 +218,7 @@ def test_filter_primfunc(): # Test condition that does not filter out anything def checker_filter_out_none(func: tvm.tir.PrimFunc): - return (func.attrs is not None) and ("temp" in func.attrs) + return "temp" in func.attrs after = tvm.tir.transform.Filter(checker_filter_out_none)(mod) assert len(after.functions) == 2 @@ -228,7 +228,7 @@ def checker_filter_out_none(func: tvm.tir.PrimFunc): # Test condition that selectively filters out primfuncs def checker_filter_out_one(func: tvm.tir.PrimFunc): - return (func.attrs is not None) and ("temp" in func.attrs) and func.attrs["temp"] == "test1" + return ("temp" in func.attrs) and func.attrs["temp"] == "test1" after = tvm.tir.transform.Filter(checker_filter_out_one)(mod) assert len(after.functions) == 1 @@ -237,7 +237,7 @@ def checker_filter_out_one(func: tvm.tir.PrimFunc): # Test condition that filters out everything def checker_filter_out_both(func: tvm.tir.PrimFunc): - return (func.attrs is not None) and ("invalid_attr" in func.attrs) + return "invalid_attr" in func.attrs after = tvm.tir.transform.Filter(checker_filter_out_both)(mod) assert len(after.functions) == 0