Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -983,16 +983,23 @@ class FunctionNode : public BaseFuncNode {
class Function : public BaseFunc {
public:
TVM_DLL explicit Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
bool is_pure = true, DictAttrs attrs = NullValue<DictAttrs>(),
Span span = Span());
bool is_pure, DictAttrs attrs = NullValue<DictAttrs>(),
Span span = Span())
: Function(params, body, ret_struct_info, Optional<Bool>(Bool(is_pure)), attrs, span) {}

TVM_DLL explicit Function(Array<Var> params, Expr body,
Optional<StructInfo> ret_struct_info = NullOpt,
Optional<Bool> is_pure = NullOpt,
DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());

/*!
* \brief Mimics the constructor but without body Expr.
* \note ret_struct_info is required, since it can not deduced by the body.
*
* \note `ret_struct_info` and `is_pure` are required, since they
* cannot be deduced from the body.
*/
TVM_DLL static Function CreateEmpty(Array<Var> params, StructInfo ret_struct_info,
bool is_pure = true, DictAttrs attrs = NullValue<DictAttrs>(),
Span span = Span());
TVM_DLL static Function CreateEmpty(Array<Var> params, StructInfo ret_struct_info, bool is_pure,
DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/script/ir_builder/relax/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ namespace relax {
* \param is_private Whether the function is annotated as private.
* \return The created ir_builder Function frame.
*/
TVM_DLL FunctionFrame Function(const Bool& is_pure, const Bool& is_private);
TVM_DLL FunctionFrame Function(const Optional<Bool>& is_pure, const Bool& is_private);

/*!
* \brief Add a parameter to the last function frame.
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relax/block_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,8 +638,8 @@ def emit_func_output(
finally:
self.end_scope()

# do not specify ret_struct_info and let constructor deduce
# from seqe.struct_info
# Do not specify ret_struct_info or purity, and let the
# constructor deduce from seqe.struct_info.
func = rx.Function(self._func._params, seqe)
for key, value in self._func._attrs.items():
func = func.with_attr(key, value)
Expand Down
16 changes: 13 additions & 3 deletions python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,10 +887,11 @@ def __init__(
params: List[Var],
body: Expr,
ret_struct_info: Optional[StructInfo] = None,
is_pure: Optional[bool] = True,
attrs: Optional[tvm.ir.DictAttrs] = None,
is_pure: Optional[bool] = None,
attrs: Optional[Union[tvm.ir.DictAttrs, Mapping]] = None,
span: Optional[Span] = None,
) -> None:
attrs = Function._normalize_attrs(attrs)
self.__init_handle_by_constructor__(
_ffi_api.Function,
params,
Expand All @@ -906,14 +907,23 @@ def create_empty(
params: List[Var],
ret_struct_info: StructInfo,
is_pure: Optional[bool] = True,
attrs: Optional[tvm.ir.DictAttrs] = None,
attrs: Optional[Union[tvm.ir.DictAttrs, Mapping]] = None,
span: Optional[Span] = None,
):
"""Construct a relax.Function but without body"""
attrs = Function._normalize_attrs(attrs)

return _ffi_api.FunctionCreateEmpty(
params, ret_struct_info, is_pure, attrs, span
) # type: ignore

@staticmethod
def _normalize_attrs(attrs: Optional[Union[tvm.ir.DictAttrs, Mapping]]) -> tvm.ir.DictAttrs:
if attrs is None or isinstance(attrs, tvm.ir.DictAttrs):
return attrs
else:
return tvm.ir.make_node("DictAttrs", **attrs)

def __call__(self, *args):
"""Invoke the global function.

Expand Down
13 changes: 6 additions & 7 deletions python/tvm/relax/frontend/nn/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,7 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]:
with self:
if effects:
with self.builder.function("_initialize_effect"):
with self.builder.dataflow():
outputs = _emit_effect_init(self.builder, effects)
outputs = _emit_effect_init(self.builder, effects)
self.builder.emit_func_output(outputs, params=[])
for method_name, method_spec in zip(spec.method_names, spec.method_specs):
params = _params() # Re-initialize so symbolic shapes not shared across methods
Expand All @@ -132,12 +131,12 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]:
method_name,
attrs={"num_input": len_args + len_effects}, # type: ignore
):
with self.builder.dataflow():
outputs, inputs = _emit_method(self.builder, method_spec, params, effects)
outputs, inputs = _emit_method(self.builder, method_spec, params, effects)
self.builder.emit_func_output(outputs, inputs)
mod = self.builder.finalize()
assert rx.analysis.well_formed(mod)

mod = rx.transform.ConvertToDataflow(min_size=1)(mod)
return mod, params, ext_mods


Expand All @@ -150,7 +149,7 @@ def _emit_effect_init(
inits = effect.emit_init(prefix, builder)
assert isinstance(inits, list)
outputs.extend(inits)
outputs = builder.emit_output(builder.emit(rx.Tuple(outputs)))
outputs = builder.emit(rx.Tuple(outputs))
return outputs


Expand Down Expand Up @@ -281,9 +280,9 @@ def _detuple(arg, var: rx.Var, builder: BlockBuilder):
for _, effect in effects:
effect_outputs.extend(effect.finalize())
if effect_outputs and spec.effect_mode != "none":
outputs = builder.emit_output(rx.Tuple([_unwrap_ret(outputs), rx.Tuple(effect_outputs)]))
outputs = builder.emit(rx.Tuple([_unwrap_ret(outputs), rx.Tuple(effect_outputs)]))
else:
outputs = builder.emit_output(_unwrap_ret(outputs))
outputs = builder.emit(_unwrap_ret(outputs))
return outputs, inputs


Expand Down
15 changes: 7 additions & 8 deletions python/tvm/relax/frontend/nn/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1897,15 +1897,14 @@ def debug_func(lineno: str, arg_0, arg_1, ...) -> None:
else:
raise TypeError(f"Unsupported type {type(arg)}")

func = rx.ExternFunc("vm.builtin.invoke_debug_func")
call = rx.Call(
func,
[io.effect, rx.StringImm(name), rx.StringImm(_line_info), *converted_args],
sinfo_args=[rx.ObjectStructInfo()],
)
io.effect = BlockBuilder.current().emit(
rx.call_pure_packed(
"vm.builtin.invoke_debug_func",
io.effect,
rx.StringImm(name),
rx.StringImm(_line_info),
*converted_args,
sinfo_args=[rx.ObjectStructInfo()],
),
call,
name_hint=io.effect.name_hint,
)

Expand Down
9 changes: 6 additions & 3 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,14 +222,17 @@ def to_vdevice(data: Expr, dst_vdevice: Union[py_str, VDevice]) -> Expr:
############################### Function ################################


def function(is_pure: bool = True, is_private: bool = False) -> frame.FunctionFrame:
def function(is_pure: Optional[bool] = None, is_private: bool = False) -> frame.FunctionFrame:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docs for is_pure should mention that it's inferred.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And updated.

"""Start a function frame.
Parameters
----------
is_pure: bool
Whether the function is annotated as pure.
is_pure: Optional[bool]

Whether the function is pure. If not explicitly specified,
will be inferred from the function's body.

is_private : bool

Whether the function is annotated as private.

Returns
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/script/parser/relax/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@

############################## R.function ##############################


# this formulation allows us to support having @R.function
# appear as a decorator by itself or to have optional arguments
# like @R.function(pure=False)
def function(
f: Optional[FType] = None, pure: bool = True, private: bool = False
f: Optional[FType] = None, pure: Optional[bool] = None, private: bool = False
) -> Union[Function, FType]:
# pylint: disable=unused-argument
# (pure and private aren't used here, but are used later in parsing)
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/script/parser/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
local_func_var = relax.Var(node.name, relax.FuncStructInfo(params_sinfo, ret_sinfo))
self.var_table.add(node.name, local_func_var)

purity = find_decorator_annotation(node, "pure")
purity = find_decorator_annotation(node, "pure", default=None)
# treat the function as private if we are inside another function
# or if it has a privacy annotation
privacy = is_inner_function or find_decorator_annotation(node, "private", default=False)
Expand Down Expand Up @@ -367,7 +367,6 @@ def visit_if(self: Parser, node: doc.If) -> None:
@dispatch.register(token="relax", type_name="enter_token")
def enter_token(self: Parser) -> Dict[str, Any]:
def relax_call(self, *args) -> Expr:

args = [convert_to_expr(arg) if isinstance(arg, tuple) else arg for arg in args]

if all(isinstance(x, Expr) for x in args):
Expand Down
23 changes: 20 additions & 3 deletions src/relax/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -441,8 +441,8 @@ TVM_REGISTER_GLOBAL("relax.SeqExpr")

TVM_REGISTER_NODE_TYPE(FunctionNode);

Function::Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info, bool is_pure,
DictAttrs attrs, Span span) {
Function::Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
Optional<Bool> is_pure_override, DictAttrs attrs, Span span) {
// Set the function type.
// For function, we take a conservative approach and require the function type
// to be known at construction time.
Expand Down Expand Up @@ -473,6 +473,23 @@ Function::Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct
ret_struct_info = body_sinfo;
}

bool is_pure = [&]() -> bool {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would this behave in the recursive or mutually recursive case? That was the reason for not inferring purity in the first place. Detecting those cases and warning the user that they need to be explicitly annotated would be a reasonable approach

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, the struct info for a GlobalVar is determined only by the forward declaration, and not by its body. The default for this is determined here (here), where a function is pure unless explicitly annotated otherwise.

We have a similar problem with return values, where the inferred return type of a function may not omitted, resulting in incorrect struct inference in the calling scope. I think the long-term solution to both is the same: To represent the lack of information while parsing, and to infer full information as a post-proc.

if (is_pure_override.defined()) {
return is_pure_override.value()->value;
}

if (attrs.defined() && attrs->dict.defined()) {
if (auto opt_force_pure = attrs->dict.Get(relax::attr::kForcePure)) {
bool force_pure = Downcast<IntImm>(opt_force_pure.value())->value;
if (force_pure) {
return true;
}
}
}

return !ContainsImpureCall(body);
}();

FuncStructInfo func_sinfo(param_sinfo, ret_struct_info.value(), is_pure);

// set the fields
Expand All @@ -490,7 +507,7 @@ Function::Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct

TVM_REGISTER_GLOBAL("relax.Function")
.set_body_typed([](Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
bool is_pure, DictAttrs attrs, Span span) {
Optional<Bool> is_pure, DictAttrs attrs, Span span) {
return Function(params, body, ret_struct_info, is_pure, attrs, span);
});

Expand Down
2 changes: 1 addition & 1 deletion src/script/ir_builder/relax/frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ void FunctionFrameNode::ExitWithScope() {
tvm::relax::Function func(/*params=*/params,
/*body=*/body,
/*ret_struct_info=*/ret_struct_info,
/*is_pure=*/is_pure.value_or(Bool(true))->value,
/*is_pure=*/is_pure,
/*attrs=*/dict_attrs);
// Step 2: Update IRModule.
if (builder->frames.empty()) {
Expand Down
2 changes: 1 addition & 1 deletion src/script/ir_builder/relax/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable)

/////////////////////////////// Function ////////////////////////////////

FunctionFrame Function(const Bool& is_pure, const Bool& is_private) {
FunctionFrame Function(const Optional<Bool>& is_pure, const Bool& is_private) {
ObjectPtr<FunctionFrameNode> n = make_object<FunctionFrameNode>();
const IRBuilder& ir_builder = IRBuilder::Current();
Optional<tvm::IRModule> mod = NullOpt;
Expand Down
48 changes: 44 additions & 4 deletions tests/python/relax/test_analysis_well_formed.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ def test_inline_prim_func():
y,
),
R.Tensor(ndim=0, dtype="int32"),
is_pure=True,
).with_attr("global_symbol", "foo")
new_mod = tvm.IRModule.from_expr(new_func)
assert not rx.analysis.well_formed(new_mod, check_struct_info=False)
Expand Down Expand Up @@ -563,11 +564,45 @@ def test_unlabeled_impure():
x = rx.Var("x", R.Tensor((), dtype="int32"))
y = rx.Var("y")
block = rx.BindingBlock([rx.VarBinding(y, rx.op.print(x, format="{}"))])
# print is impure, but the function is not labeled as impure
# Calls to `relax.op.print` are impure, but the function is not
# explicitly labeled as impure. If there is no user-specified
# purity annotation, the function's purity is inferred from the
# body.
func = rx.Function([x], rx.SeqExpr([block], x), R.Tensor((), dtype="int32")).with_attr(
"global_symbol", "foo"
)
assert not func.is_pure
mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
assert rx.analysis.well_formed(mod)


def test_unlabeled_pure():
x = rx.Var("x", R.Tensor((), dtype="int32"))
# There are no calls to impure functions, so the `relax::Function`
# constructor infers that the function is pure.
func = rx.Function([x], rx.SeqExpr([], x), R.Tensor((), dtype="int32")).with_attr(
"global_symbol", "foo"
)
assert func.is_pure
mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
assert rx.analysis.well_formed(mod)


def test_labeled_pure():
x = rx.Var("x", R.Tensor((), dtype="int32"))
y = rx.Var("y")
block = rx.BindingBlock([rx.VarBinding(y, rx.op.print(x, format="{}"))])
# Calls to `relax.op.print` are impure, but the function is
# explicitly labeled as pure.
func = rx.Function(
[x], rx.SeqExpr([block], x), R.Tensor((), dtype="int32"), is_pure=True
).with_attr("global_symbol", "foo")
# The explicit argument is used for the function's purity.
assert func.is_pure
mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))

# The function is ill-formed, as the function's purity does not
# match the explicit annotation.
assert not rx.analysis.well_formed(mod)


Expand All @@ -576,10 +611,11 @@ def test_labeled_impure():
x = rx.Var("x", R.Tensor((), dtype="int32"))
y = rx.Var("y")
block = rx.BindingBlock([rx.VarBinding(y, rx.op.print(x, format="{}"))])
# print is impure, but the function is not labeled as impure
# print is impure, and the function is labeled as impure.
func = rx.Function(
[x], rx.SeqExpr([block], x), R.Tensor((), dtype="int32"), is_pure=False
).with_attrs({"global_symbol": "foo"})
assert not func.is_pure
mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
assert rx.analysis.well_formed(mod)

Expand All @@ -589,9 +625,13 @@ def test_force_pure():
y = rx.Var("y")
block = rx.BindingBlock([rx.VarBinding(y, rx.op.print(x, format="{}"))])
# print is impure, but force_pure overrides the judgment
func = rx.Function([x], rx.SeqExpr([block], x), R.Tensor((), dtype="int32")).with_attrs(
{"global_symbol": "foo", "relax.force_pure": True}
func = rx.Function(
[x],
rx.SeqExpr([block], x),
R.Tensor((), dtype="int32"),
attrs={"global_symbol": "foo", "relax.force_pure": True},
)
assert func.is_pure
mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
assert rx.analysis.well_formed(mod)

Expand Down
Loading