Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ class DictAttrs : public Attrs {
* \brief Consruct a Attrs backed by DictAttrsNode.
* \param dict The attributes.
*/
TVM_DLL explicit DictAttrs(Map<String, ObjectRef> dict);
TVM_DLL explicit DictAttrs(Map<String, ObjectRef> dict = {});

// Utils for accessing attributes
// This needs to be on DictAttrs, not DictAttrsNode because we return the default
Expand Down Expand Up @@ -298,7 +298,7 @@ class DictAttrs : public Attrs {
return GetAttr<Integer>(attr_key, 0).value_or(0).IntValue() != 0;
}

TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode);
TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(DictAttrs, Attrs, DictAttrsNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
};

Expand Down Expand Up @@ -415,9 +415,6 @@ inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) {
if (input->attrs.defined()) {
TNode* node = input.CopyOnWrite();
node->attrs.CopyOnWrite()->dict.erase(attr_key);
if (node->attrs->dict.size() == 0) {
node->attrs = NullValue<DictAttrs>();
}
}
return input;
}
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,8 @@ class IRModule : public ObjectRef {
TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
Map<GlobalTypeVar, TypeData> type_definitions = {},
std::unordered_set<String> import_set = {}, SourceMap map = {},
DictAttrs attrs = {}, Map<String, Array<GlobalInfo>> global_infos = {});
DictAttrs attrs = DictAttrs(),
Map<String, Array<GlobalInfo>> global_infos = {});

/*! \brief default constructor */
IRModule() : IRModule(Map<GlobalVar, BaseFunc>({})) {}
Expand Down
5 changes: 2 additions & 3 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -983,15 +983,14 @@ class FunctionNode : public BaseFuncNode {
class Function : public BaseFunc {
public:
TVM_DLL explicit Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
bool is_pure = true, DictAttrs attrs = NullValue<DictAttrs>(),
Span span = Span());
bool is_pure = true, DictAttrs attrs = DictAttrs(), Span span = Span());

/*!
* \brief Mimics the constructor but without body Expr.
* \note ret_struct_info is required, since it can not deduced by the body.
*/
TVM_DLL static Function CreateEmpty(Array<Var> params, StructInfo ret_struct_info,
bool is_pure = true, DictAttrs attrs = NullValue<DictAttrs>(),
bool is_pure = true, DictAttrs attrs = DictAttrs(),
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class Function : public BaseFunc {
* \param span The span of the function.
*/
TVM_DLL Function(tvm::Array<Var> params, Expr body, Type ret_type, tvm::Array<TypeVar> ty_params,
tvm::DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
tvm::DictAttrs attrs = DictAttrs(), Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
Expand Down
14 changes: 12 additions & 2 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -741,14 +741,24 @@ struct ObjectPtrEqual {
* \param ParentType The parent type of the objectref
* \param ObjectName The type name of the object.
*/
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
TypeName() = default; \
#define TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TypeName, ParentType, \
ObjectName) \
explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \
const ObjectName* get() const { return operator->(); } \
using ContainerType = ObjectName;

/*
* \brief Define object reference methods.
* \param TypeName The object type name
* \param ParentType The parent type of the objectref
* \param ObjectName The type name of the object.
*/
#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \
TypeName() = default; \
TVM_DEFINE_OBJECT_REF_METHODS_WITHOUT_DEFAULT_CONSTRUCTOR(TypeName, ParentType, ObjectName)

/*
* \brief Define object reference methods that is not nullable.
*
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/script/ir_builder/tir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class PrimFuncFrameNode : public TIRFrameNode {
/*! \brief Maps some parameters to specific Buffer data structures. */
Map<tvm::tir::Var, tvm::tir::Buffer> buffer_map;
/*! \brief Additional attributes storing the meta-data */
Optional<Map<String, ObjectRef>> attrs;
Map<String, ObjectRef> attrs;
/*! \brief The variable map bound to thread env. */
Map<tvm::tir::Var, tvm::tir::IterVar> env_threads;
/*! \brief The buffer allocated in root block. */
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ class PrimFunc : public BaseFunc {
*/
TVM_DLL PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(),
DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
DictAttrs attrs = DictAttrs(), Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/contrib/cutlass/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,7 @@ def handle_norm(self, f, op_type):
return f.with_attrs(attrs)

def visit_function_(self, f):
if f.attrs is None or "Composite" not in f.attrs:
if b"Composite" not in f.attrs:
body = super().visit_expr(f.body)
return relax.Function(f.params, body, f.ret_struct_info, f.is_pure, f.attrs, f.span)

Expand Down
23 changes: 11 additions & 12 deletions python/tvm/contrib/relay_viz/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,14 +213,14 @@ def _function(
node_to_id: Dict[relay.Expr, str],
) -> Tuple[Union[VizNode, None], List[VizEdge]]:
"""Render rule for a relay function node"""
node_details = []
name = ""
func_attrs = node.attrs
if func_attrs:
node_details = [f"{k}: {func_attrs.get_str(k)}" for k in func_attrs.keys()]
# "Composite" might from relay.transform.MergeComposite
if "Composite" in func_attrs.keys():
name = func_attrs["Composite"]
node_details = [f"{k}: {func_attrs.get_str(k)}" for k in func_attrs.keys()]
# "Composite" might from relay.transform.MergeComposite
if "Composite" in func_attrs.keys():
name = func_attrs["Composite"]
else:
name = ""

node_id = node_to_id[node]

# Body -> FunctionNode
Expand All @@ -244,11 +244,10 @@ def _call(
elif isinstance(node.op, relay.Function):
func_attrs = node.op.attrs
op_name = "Anonymous Func"
if func_attrs:
node_detail = [f"{k}: {func_attrs.get_str(k)}" for k in func_attrs.keys()]
# "Composite" might from relay.transform.MergeComposite
if "Composite" in func_attrs.keys():
op_name = func_attrs["Composite"]
node_detail = [f"{k}: {func_attrs.get_str(k)}" for k in func_attrs.keys()]
# "Composite" might from relay.transform.MergeComposite
if "Composite" in func_attrs.keys():
op_name = func_attrs["Composite"]
elif isinstance(node.op, relay.GlobalVar):
op_name = "GlobalVar"
node_detail = [f"GlobalVar.name_hint: {node.op.name_hint}"]
Expand Down
2 changes: 0 additions & 2 deletions python/tvm/dlight/base/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,6 @@
def _is_scheduled(func: tir.PrimFunc) -> bool:
if not isinstance(func, tir.PrimFunc):
return False
if not func.attrs:
return False
if "tir.is_scheduled" not in func.attrs:
return False
return func.attrs["tir.is_scheduled"] == 1
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/dlight/gpu/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
root_block = analysis.get_root_block(sch)
blocks = sch.get_child_blocks(root_block)

if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys():
if "dlight.do_not_tensorize" in func.attrs.keys():
return None

reduction_blocks = get_reduction_blocks(sch, blocks)
Expand Down Expand Up @@ -556,7 +556,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
root_block = analysis.get_root_block(sch)
blocks = sch.get_child_blocks(root_block)

if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys():
if "dlight.do_not_tensorize" in func.attrs.keys():
return None

reduction_blocks = get_reduction_blocks(sch, blocks)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def build(
if target is None and isinstance(input_mod, tvm.IRModule):
target_mod = {}
for gvar, func in input_mod.functions.items():
tgt = func.attrs["target"] if func.attrs and "target" in func.attrs else "llvm"
tgt = func.attrs["target"] if "target" in func.attrs else "llvm"
if tgt not in target_mod:
target_mod[tgt] = {}
target_mod[tgt][gvar] = func
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/ir/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ def keys(self):
def __getitem__(self, k):
return self._dict().__getitem__(k)

def get(self, key, default=None):
"""Get an element with a default value."""
return self._dict().get(key, default)

def __contains__(self, k):
return self._dict().__contains__(k)

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/relax_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def extracted_tasks_to_tune_contexts(
get_loggers_from_work_dir(work_dir, [t.task_name for t in extracted_tasks]),
fork_seed(seed, n=len(extracted_tasks)),
):
if task.mod.attrs is not None and task.mod.attrs.get("tir.is_scheduled", False):
if task.mod.attrs.get("tir.is_scheduled", False):
warnings.warn("The task {task.task_name} is already scheduled, skipping it.")
continue
tasks.append(
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relax/backend/contrib/cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,11 +526,11 @@ def __init__(self, mod):
super().__init__(mod)

def visit_function_(self, f):
if f.attrs is None or "Composite" not in f.attrs:
if "Composite" not in f.attrs:
body = super().visit_expr(f.body)
new_f = Function(f.params, body, f.ret_struct_info, f.is_pure, f.attrs, f.span)

if f.attrs and "global_symbol" in f.attrs and "cutlass" in f.attrs["global_symbol"]:
if "global_symbol" in f.attrs and "cutlass" in f.attrs["global_symbol"]:
composite_func = body.blocks[0].bindings[0].value
if "WorkspaceSize" in composite_func.attrs:
return new_f.with_attr("WorkspaceSize", composite_func.attrs["WorkspaceSize"])
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def detach_params(mod: tvm.IRModule) -> Tuple[tvm.IRModule, Dict[str, List[tvm.n
detached_mod = tvm.IRModule()
params_dict = dict()
for gv, func in mod.functions_items():
if func.attrs is not None and "params" in func.attrs:
if "params" in func.attrs:
params = list(func.attrs["params"])
if not all([isinstance(param, tvm.nd.NDArray) for param in params]):
raise ValueError(
Expand Down
12 changes: 4 additions & 8 deletions python/tvm/relax/training/setup_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,19 +138,15 @@ def _check_well_formed(self, mod: IRModule):
) from exc

# Check function attrs
if (
mod.attrs is None
or not self.PARAM_NUM_ATTR_KEY in mod.attrs
or not isinstance(mod.attrs[self.PARAM_NUM_ATTR_KEY], IntImm)
if not self.PARAM_NUM_ATTR_KEY in mod.attrs or not isinstance(
mod.attrs[self.PARAM_NUM_ATTR_KEY], IntImm
):
raise ValueError(
f"SetupTrainer: The backbone module should has an integer attribute named "
f"{self.PARAM_NUM_ATTR_KEY}"
)
if (
mod.attrs is None
or not self.STATE_NUM_ATTR_KEY in mod.attrs
or not isinstance(mod.attrs[self.STATE_NUM_ATTR_KEY], IntImm)
if not self.STATE_NUM_ATTR_KEY in mod.attrs or not isinstance(
mod.attrs[self.STATE_NUM_ATTR_KEY], IntImm
):
raise ValueError(
f"SetupTrainer: The backbone module should has an integer attribute named "
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relax/transform/lazy_transform_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def __init__(
self.memory_free_insertion = None

def transform(self, func: relax.Function) -> relax.Function:
if func.attrs is not None and "num_input" in func.attrs:
if "num_input" in func.attrs:
num_input = func.attrs["num_input"].value
else:
num_input = 0
Expand Down Expand Up @@ -235,7 +235,7 @@ def __init__(self, func_creator, mod: Optional[IRModule] = None) -> None:
super().__init__(mod)

def visit_function_(self, func: relax.Function) -> relax.Expr:
if func.attrs is not None and "num_input" in func.attrs:
if "num_input" in func.attrs:
num_input = func.attrs["num_input"].value
else:
num_input = 0
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class QPadArgs(Enum):

def is_npu_func(func: relay.Function) -> bool:
"""Check if the given function is an NPU function."""
return func.attrs and "Compiler" in func.attrs and func.attrs["Compiler"] == "ethos-u"
return "Compiler" in func.attrs and func.attrs["Compiler"] == "ethos-u"


def is_composite_func(func: relay.Function, name: str) -> bool:
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/relay/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def __init__(self, params, body, ret_type=None, type_params=None, attrs=None, sp
if type_params is None:
type_params = convert([])

if attrs is None:
attrs = tvm.ir.make_node("DictAttrs")

self.__init_handle_by_constructor__(
_ffi_api.Function, params, body, ret_type, type_params, attrs, span
)
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/quantize/_partition_conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def partition_prefix(mod, quantized_dtypes):
prefix_cutter = PrefixCutter(func.params, quantized_dtypes)
mid_body = prefix_cutter.visit(func.body)
assert not func.type_params, "unimplemented"
assert func.attrs is None, "unimplemented"
assert not func.attrs, "unimplemented"
mid_func = relay.Function(relay.analysis.free_vars(mid_body), mid_body)
mid_mod = tvm.IRModule.from_expr(mid_func)
mid_mod = relay.transform.InferType()(mid_mod)
Expand Down Expand Up @@ -288,7 +288,7 @@ def partition_suffix(mod, quantized_dtypes):
suffix_cutter = SuffixCutter(quantized_dtypes)
post_body = suffix_cutter.visit(func.body)
assert not func.type_params, "unimplemented"
assert func.attrs is None, "unimplemented"
assert not func.attrs, "unimplemented"
post_func = relay.Function(relay.analysis.free_vars(post_body), post_body, func.ret_type)
post_mod = tvm.IRModule.from_expr(post_func)
post_mod = relay.transform.InferType()(post_mod)
Expand Down
6 changes: 5 additions & 1 deletion python/tvm/relay/testing/py_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,11 @@ def visit_call(self, call: Expr):

# lowered operator: generate a call to a function that gets the PackedFunc
# from TVM's registry
if isinstance(func, Function) and func.attrs and func.attrs.Primitive.value == 1:
if (
isinstance(func, Function)
and hasattr(func.attrs, "Primitive")
and int(func.attrs.Primitive) == 1
):
op_call_def, op_call = self.create_op_call(func, call.args, fields)
return (op_call, field_defs + [op_call_def])

Expand Down
6 changes: 3 additions & 3 deletions python/tvm/testing/aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,12 +1001,12 @@ def generate_ref_data(mod, input_data, params=None, target="llvm"):
main = mod
else:
main = mod["main"]
if main.attrs is None or main.attrs["output_tensor_names"] is None:
if "output_tensor_names" in main.attrs:
output_tensor_names = main.attrs["output_tensor_names"]
else:
output_tensor_names = (
["output"] if output_count == 1 else [f"output{i}" for i in range(output_count)]
)
else:
output_tensor_names = main.attrs["output_tensor_names"]

return dict(zip(output_tensor_names, out))

Expand Down
3 changes: 3 additions & 0 deletions python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def __init__(
else:
raise TypeError("params can only contain Var or Buffer")

if attrs is None:
attrs = tvm.ir.make_node("DictAttrs")

self.__init_handle_by_constructor__(
_ffi_api.PrimFunc,
param_list,
Expand Down
2 changes: 1 addition & 1 deletion src/relay/analysis/type_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ TVM_REGISTER_GLOBAL("relay.analysis._test_type_solver")
auto module = IRModule({}, {});
DiagnosticContext diag_ctx = DiagnosticContext::Default(module);
auto dummy_fn_name = GlobalVar("test");
module->Add(dummy_fn_name, Function({}, Tuple(tvm::Array<relay::Expr>({})), Type(), {}, {}));
module->Add(dummy_fn_name, Function({}, Tuple(tvm::Array<relay::Expr>({})), Type(), {}));
auto solver = std::make_shared<TypeSolver>(dummy_fn_name, diag_ctx);

auto mod = [module, solver, diag_ctx](std::string name) -> PackedFunc {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/vm/lambda_lift.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator {
CHECK_EQ(before_arity, after_arity);
lifted_func =
Function(typed_captured_vars, rebound_body, /*ret_type=*/func->func_type_annotation(),
free_type_vars, /*attrs=*/{}, func->span);
free_type_vars, DictAttrs(), func->span);
lifted_func->virtual_device_ = result_virtual_device;
lifted_func = MarkClosure(lifted_func);
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ Expr InferTypeWithModule(const Expr& expr, const IRModule& m) {
if (expr.as<FunctionNode>()) {
func = Downcast<Function>(expr);
} else {
func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {});
func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod));
}
mod->Add(gvar, func);
mod = transform::InferType()(mod);
Expand Down
4 changes: 3 additions & 1 deletion src/relay/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ namespace relay {

Function::Function(tvm::Array<Var> params, Expr body, Type ret_type,
tvm::Array<TypeVar> type_params, DictAttrs attrs, Span span) {
CHECK(attrs.defined());

ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
ICHECK(params.defined());
ICHECK(type_params.defined());
Expand Down Expand Up @@ -251,7 +253,7 @@ TVM_REGISTER_GLOBAL("relay.ir.IRModuleUpdateWithRenamer")

TVM_REGISTER_GLOBAL("relay.ir.FunctionFromExprInContext")
.set_body_typed([](RelayExpr expr, IRModule mod) -> Function {
return Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {});
return Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod));
});

TVM_REGISTER_GLOBAL("relay.ir.FuncWithAttr")
Expand Down
Loading