diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index fdbd7bd8eb2c..58c3887d2988 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -983,16 +983,23 @@ class FunctionNode : public BaseFuncNode {
 class Function : public BaseFunc {
  public:
   TVM_DLL explicit Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
-                            bool is_pure = true, DictAttrs attrs = NullValue<DictAttrs>(),
-                            Span span = Span());
+                            bool is_pure, DictAttrs attrs = NullValue<DictAttrs>(),
+                            Span span = Span())
+      : Function(params, body, ret_struct_info, Optional<Bool>(Bool(is_pure)), attrs, span) {}
+
+  TVM_DLL explicit Function(Array<Var> params, Expr body,
+                            Optional<StructInfo> ret_struct_info = NullOpt,
+                            Optional<Bool> is_pure = NullOpt,
+                            DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
 
   /*!
    * \brief Mimics the constructor but without body Expr.
-   * \note ret_struct_info is required, since it can not deduced by the body.
+   *
+   * \note `ret_struct_info` and `is_pure` are required, since they
+   * cannot be deduced from the body.
    */
-  TVM_DLL static Function CreateEmpty(Array<Var> params, StructInfo ret_struct_info,
-                                      bool is_pure = true, DictAttrs attrs = NullValue<DictAttrs>(),
-                                      Span span = Span());
+  TVM_DLL static Function CreateEmpty(Array<Var> params, StructInfo ret_struct_info, bool is_pure,
+                                      DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
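
The delegating overload above keeps existing `bool is_pure` call sites working while routing construction through the new `Optional<Bool>` parameter. A minimal Python-side sketch of the two paths (variable names illustrative; the constructs mirror the tests further down in this patch):

```python
from tvm import relax as rx
from tvm.script import relax as R

x = rx.Var("x", R.Tensor((), dtype="int32"))

# Explicit annotation: the stated purity is used as-is.
annotated = rx.Function([x], rx.SeqExpr([], x), R.Tensor((), dtype="int32"), is_pure=True)

# No annotation: `is_pure` stays NullOpt, and the constructor infers
# purity from the body (see src/relax/ir/expr.cc below).
inferred = rx.Function([x], rx.SeqExpr([], x), R.Tensor((), dtype="int32"))
assert inferred.is_pure
```
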
diff --git a/include/tvm/script/ir_builder/relax/ir.h b/include/tvm/script/ir_builder/relax/ir.h
index d160ad090e48..42d3de27a329 100644
--- a/include/tvm/script/ir_builder/relax/ir.h
+++ b/include/tvm/script/ir_builder/relax/ir.h
@@ -37,7 +37,7 @@ namespace relax {
  * \param is_private Whether the function is annotated as private.
  * \return The created ir_builder Function frame.
  */
-TVM_DLL FunctionFrame Function(const Bool& is_pure, const Bool& is_private);
+TVM_DLL FunctionFrame Function(const Optional<Bool>& is_pure, const Bool& is_private);
 
 /*!
  * \brief Add a parameter to the last function frame.
diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py
index 330585599d08..c18987ef2e46 100644
--- a/python/tvm/relax/block_builder.py
+++ b/python/tvm/relax/block_builder.py
@@ -638,8 +638,8 @@ def emit_func_output(
         finally:
             self.end_scope()
 
-        # do not specify ret_struct_info and let constructor deduce
-        # from seqe.struct_info
+        # Do not specify ret_struct_info or purity; let the
+        # constructor deduce them from seqe.
         func = rx.Function(self._func._params, seqe)
         for key, value in self._func._attrs.items():
             func = func.with_attr(key, value)
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index 12f08f4dbf1a..9168afabb18e 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -887,10 +887,11 @@ def __init__(
         params: List[Var],
         body: Expr,
         ret_struct_info: Optional[StructInfo] = None,
-        is_pure: Optional[bool] = True,
-        attrs: Optional[tvm.ir.DictAttrs] = None,
+        is_pure: Optional[bool] = None,
+        attrs: Optional[Union[tvm.ir.DictAttrs, Mapping]] = None,
         span: Optional[Span] = None,
     ) -> None:
+        attrs = Function._normalize_attrs(attrs)
         self.__init_handle_by_constructor__(
             _ffi_api.Function,
             params,
@@ -906,14 +907,23 @@ def create_empty(
         params: List[Var],
         ret_struct_info: StructInfo,
         is_pure: Optional[bool] = True,
-        attrs: Optional[tvm.ir.DictAttrs] = None,
+        attrs: Optional[Union[tvm.ir.DictAttrs, Mapping]] = None,
         span: Optional[Span] = None,
     ):
         """Construct a relax.Function but without body"""
+        attrs = Function._normalize_attrs(attrs)
+
         return _ffi_api.FunctionCreateEmpty(
             params, ret_struct_info, is_pure, attrs, span
         )  # type: ignore
 
+    @staticmethod
+    def _normalize_attrs(attrs: Optional[Union[tvm.ir.DictAttrs, Mapping]]) -> tvm.ir.DictAttrs:
+        if attrs is None or isinstance(attrs, tvm.ir.DictAttrs):
+            return attrs
+        else:
+            return tvm.ir.make_node("DictAttrs", **attrs)
+
     def __call__(self, *args):
         """Invoke the global function.
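
With `_normalize_attrs` in place, the Python constructor now also accepts a plain mapping for `attrs`, converting it to `DictAttrs` via `tvm.ir.make_node("DictAttrs", **attrs)` before crossing the FFI boundary. A short usage sketch (the attribute key follows the tests below):

```python
from tvm import relax as rx
from tvm.script import relax as R

x = rx.Var("x", R.Tensor((), dtype="int32"))

# Passing a dict is equivalent to passing
# tvm.ir.make_node("DictAttrs", **{"global_symbol": "main"}).
func = rx.Function(
    [x],
    rx.SeqExpr([], x),
    R.Tensor((), dtype="int32"),
    attrs={"global_symbol": "main"},
)
assert func.attrs["global_symbol"] == "main"
```
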
diff --git a/python/tvm/relax/frontend/nn/exporter.py b/python/tvm/relax/frontend/nn/exporter.py
index 1a7dcd6a648b..853d9ba0278b 100644
--- a/python/tvm/relax/frontend/nn/exporter.py
+++ b/python/tvm/relax/frontend/nn/exporter.py
@@ -117,8 +117,7 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]:
         with self:
             if effects:
                 with self.builder.function("_initialize_effect"):
-                    with self.builder.dataflow():
-                        outputs = _emit_effect_init(self.builder, effects)
+                    outputs = _emit_effect_init(self.builder, effects)
                     self.builder.emit_func_output(outputs, params=[])
             for method_name, method_spec in zip(spec.method_names, spec.method_specs):
                 params = _params()  # Re-initialize so symbolic shapes not shared across methods
@@ -132,12 +131,12 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]:
                     method_name,
                     attrs={"num_input": len_args + len_effects},  # type: ignore
                 ):
-                    with self.builder.dataflow():
-                        outputs, inputs = _emit_method(self.builder, method_spec, params, effects)
+                    outputs, inputs = _emit_method(self.builder, method_spec, params, effects)
                     self.builder.emit_func_output(outputs, inputs)
         mod = self.builder.finalize()
         assert rx.analysis.well_formed(mod)
 
+        mod = rx.transform.ConvertToDataflow(min_size=1)(mod)
         return mod, params, ext_mods
@@ -150,7 +149,7 @@ def _emit_effect_init(
         inits = effect.emit_init(prefix, builder)
         assert isinstance(inits, list)
         outputs.extend(inits)
-    outputs = builder.emit_output(builder.emit(rx.Tuple(outputs)))
+    outputs = builder.emit(rx.Tuple(outputs))
     return outputs
 
 
@@ -281,9 +280,9 @@ def _detuple(arg, var: rx.Var, builder: BlockBuilder):
     for _, effect in effects:
         effect_outputs.extend(effect.finalize())
     if effect_outputs and spec.effect_mode != "none":
-        outputs = builder.emit_output(rx.Tuple([_unwrap_ret(outputs), rx.Tuple(effect_outputs)]))
+        outputs = builder.emit(rx.Tuple([_unwrap_ret(outputs), rx.Tuple(effect_outputs)]))
     else:
-        outputs = builder.emit_output(_unwrap_ret(outputs))
+        outputs = builder.emit(_unwrap_ret(outputs))
     return outputs
 
 
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index d299d3943944..2a0d597bb0d5 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1897,15 +1897,14 @@ def debug_func(lineno: str, arg_0, arg_1, ...) -> None:
         else:
             raise TypeError(f"Unsupported type {type(arg)}")
 
+    func = rx.ExternFunc("vm.builtin.invoke_debug_func")
+    call = rx.Call(
+        func,
+        [io.effect, rx.StringImm(name), rx.StringImm(_line_info), *converted_args],
+        sinfo_args=[rx.ObjectStructInfo()],
+    )
     io.effect = BlockBuilder.current().emit(
-        rx.call_pure_packed(
-            "vm.builtin.invoke_debug_func",
-            io.effect,
-            rx.StringImm(name),
-            rx.StringImm(_line_info),
-            *converted_args,
-            sinfo_args=[rx.ObjectStructInfo()],
-        ),
+        call,
         name_hint=io.effect.name_hint,
     )
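
Two related changes meet in the frontend: the debug hook is now emitted as an ordinary, impure `rx.Call` to the extern function (rather than `rx.call_pure_packed`), and the exporter emits straight-line bindings, recovering dataflow regions afterwards with `ConvertToDataflow`. A sketch of the recovery pass in isolation (module contents illustrative):

```python
from tvm import relax as rx
from tvm.script import ir as I, relax as R

@I.ir_module
class Before:
    @R.function
    def main(x: R.Tensor((2,), dtype="float32")) -> R.Tensor((2,), dtype="float32"):
        # Straight-line bindings, no dataflow block yet.
        y = R.add(x, x)
        z = R.multiply(y, y)
        return z

# min_size=1 wraps even a single pure binding into a dataflow block;
# impure calls (such as the debug hook) remain outside.
After = rx.transform.ConvertToDataflow(min_size=1)(Before)
```
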
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 3e1927290dcc..03a79ab1b9c0 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -222,14 +222,17 @@ def to_vdevice(data: Expr, dst_vdevice: Union[py_str, VDevice]) -> Expr:
 ###############################  Function  ################################
 
 
-def function(is_pure: bool = True, is_private: bool = False) -> frame.FunctionFrame:
+def function(is_pure: Optional[bool] = None, is_private: bool = False) -> frame.FunctionFrame:
     """Start a function frame.
 
     Parameters
     ----------
-    is_pure: bool
-        Whether the function is annotated as pure.
+    is_pure: Optional[bool]
+
+        Whether the function is pure.  If not explicitly specified,
+        the purity is inferred from the function's body.
     is_private : bool
+
         Whether the function is annotated as private.
 
     Returns
diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py
index d5950dc66dce..fa0c3a87337b 100644
--- a/python/tvm/script/parser/relax/entry.py
+++ b/python/tvm/script/parser/relax/entry.py
@@ -47,11 +47,12 @@
 
 ############################## R.function ##############################
 
+
 # this formulation allows us to support having @R.function
 # appear as a decorator by itself or to have optional arguments
 # like @R.function(pure=False)
 def function(
-    f: Optional[FType] = None, pure: bool = True, private: bool = False
+    f: Optional[FType] = None, pure: Optional[bool] = None, private: bool = False
 ) -> Union[Function, FType]:
     # pylint: disable=unused-argument
     # (pure and private aren't used here, but are used later in parsing)
diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py
index 1a0c3cea8e0b..7b77fd99ce4f 100644
--- a/python/tvm/script/parser/relax/parser.py
+++ b/python/tvm/script/parser/relax/parser.py
@@ -181,7 +181,7 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None:
     local_func_var = relax.Var(node.name, relax.FuncStructInfo(params_sinfo, ret_sinfo))
     self.var_table.add(node.name, local_func_var)
 
-    purity = find_decorator_annotation(node, "pure")
+    purity = find_decorator_annotation(node, "pure", default=None)
     # treat the function as private if we are inside another function
     # or if it has a privacy annotation
     privacy = is_inner_function or find_decorator_annotation(node, "private", default=False)
@@ -367,7 +367,6 @@ def visit_if(self: Parser, node: doc.If) -> None:
 @dispatch.register(token="relax", type_name="enter_token")
 def enter_token(self: Parser) -> Dict[str, Any]:
     def relax_call(self, *args) -> Expr:
-        args = [convert_to_expr(arg) if isinstance(arg, tuple) else arg for arg in args]
         if all(isinstance(x, Expr) for x in args):
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index 1bc7267af6ca..d622f973aa73 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -441,8 +441,8 @@ TVM_REGISTER_GLOBAL("relax.SeqExpr")
 
 TVM_REGISTER_NODE_TYPE(FunctionNode);
 
-Function::Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info, bool is_pure,
-                   DictAttrs attrs, Span span) {
+Function::Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
+                   Optional<Bool> is_pure_override, DictAttrs attrs, Span span) {
   // Set the function type.
   // For function, we take a conservative approach and require the function type
   // to be known at construction time.
@@ -473,6 +473,23 @@ Function::Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct
     ret_struct_info = body_sinfo;
   }
 
+  bool is_pure = [&]() -> bool {
+    if (is_pure_override.defined()) {
+      return is_pure_override.value()->value;
+    }
+
+    if (attrs.defined() && attrs->dict.defined()) {
+      if (auto opt_force_pure = attrs->dict.Get(relax::attr::kForcePure)) {
+        bool force_pure = Downcast<Bool>(opt_force_pure.value())->value;
+        if (force_pure) {
+          return true;
+        }
+      }
+    }
+
+    return !ContainsImpureCall(body);
+  }();
+
   FuncStructInfo func_sinfo(param_sinfo, ret_struct_info.value(), is_pure);
 
   // set the fields
@@ -490,7 +507,7 @@ Function::Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct
 
 TVM_REGISTER_GLOBAL("relax.Function")
     .set_body_typed([](Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info,
-                       bool is_pure, DictAttrs attrs, Span span) {
+                       Optional<Bool> is_pure, DictAttrs attrs, Span span) {
       return Function(params, body, ret_struct_info, is_pure, attrs, span);
     });
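
The lambda above fixes the precedence for a function's purity: (1) an explicit `is_pure` argument, then (2) a true-valued `relax.force_pure` attribute, then (3) inference via `ContainsImpureCall` on the body. A compact Python illustration, using the same constructs as the tests below:

```python
from tvm import relax as rx
from tvm.script import relax as R

x = rx.Var("x", R.Tensor((), dtype="int32"))
y = rx.Var("y")
block = rx.BindingBlock([rx.VarBinding(y, rx.op.print(x, format="{}"))])

# Rule (3): no annotation, and the body contains an impure call.
inferred = rx.Function([x], rx.SeqExpr([block], x), R.Tensor((), dtype="int32"))
assert not inferred.is_pure

# Rule (2): `relax.force_pure` overrides the inference.
forced = rx.Function(
    [x],
    rx.SeqExpr([block], x),
    R.Tensor((), dtype="int32"),
    attrs={"relax.force_pure": True},
)
assert forced.is_pure
```
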
diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc
index b95db57a881b..480db363c820 100644
--- a/src/script/ir_builder/relax/frame.cc
+++ b/src/script/ir_builder/relax/frame.cc
@@ -66,7 +66,7 @@ void FunctionFrameNode::ExitWithScope() {
   tvm::relax::Function func(/*params=*/params,
                             /*body=*/body,
                             /*ret_struct_info=*/ret_struct_info,
-                            /*is_pure=*/is_pure.value_or(Bool(true))->value,
+                            /*is_pure=*/is_pure,
                             /*attrs=*/dict_attrs);
   // Step 2: Update IRModule.
   if (builder->frames.empty()) {
diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc
index 285a3a348e3b..b56934b7522f 100644
--- a/src/script/ir_builder/relax/ir.cc
+++ b/src/script/ir_builder/relax/ir.cc
@@ -52,7 +52,7 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable)
 
 ///////////////////////////////  Function  ////////////////////////////////
 
-FunctionFrame Function(const Bool& is_pure, const Bool& is_private) {
+FunctionFrame Function(const Optional<Bool>& is_pure, const Bool& is_private) {
   ObjectPtr<FunctionFrameNode> n = make_object<FunctionFrameNode>();
   const IRBuilder& ir_builder = IRBuilder::Current();
   Optional<tvm::IRModule> mod = NullOpt;
diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py
index bbf38d8c386b..09b8b881848f 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -430,6 +430,7 @@ def test_inline_prim_func():
             y,
         ),
         R.Tensor(ndim=0, dtype="int32"),
+        is_pure=True,
     ).with_attr("global_symbol", "foo")
     new_mod = tvm.IRModule.from_expr(new_func)
     assert not rx.analysis.well_formed(new_mod, check_struct_info=False)
@@ -563,11 +564,45 @@ def test_unlabeled_impure():
     x = rx.Var("x", R.Tensor((), dtype="int32"))
     y = rx.Var("y")
     block = rx.BindingBlock([rx.VarBinding(y, rx.op.print(x, format="{}"))])
-    # print is impure, but the function is not labeled as impure
+    # Calls to `relax.op.print` are impure, but the function is not
+    # explicitly labeled as impure.  If there is no user-specified
+    # purity annotation, the function's purity is inferred from the
+    # body.
     func = rx.Function([x], rx.SeqExpr([block], x), R.Tensor((), dtype="int32")).with_attr(
         "global_symbol", "foo"
     )
+    assert not func.is_pure
+
     mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
+    assert rx.analysis.well_formed(mod)
+
+
+def test_unlabeled_pure():
+    x = rx.Var("x", R.Tensor((), dtype="int32"))
+    # There are no calls to impure functions, so the `relax::Function`
+    # constructor infers that the function is pure.
+    func = rx.Function([x], rx.SeqExpr([], x), R.Tensor((), dtype="int32")).with_attr(
+        "global_symbol", "foo"
+    )
+    assert func.is_pure
+    mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
+    assert rx.analysis.well_formed(mod)
+
+
+def test_labeled_pure():
+    x = rx.Var("x", R.Tensor((), dtype="int32"))
+    y = rx.Var("y")
+    block = rx.BindingBlock([rx.VarBinding(y, rx.op.print(x, format="{}"))])
+    # Calls to `relax.op.print` are impure, but the function is
+    # explicitly labeled as pure.
+    func = rx.Function(
+        [x], rx.SeqExpr([block], x), R.Tensor((), dtype="int32"), is_pure=True
+    ).with_attr("global_symbol", "foo")
+    # The explicit argument is used for the function's purity.
+    assert func.is_pure
+    mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
+
+    # The function is ill-formed, as the function's purity does not
+    # match the explicit annotation.
     assert not rx.analysis.well_formed(mod)
@@ -576,10 +611,11 @@ def test_labeled_impure():
     x = rx.Var("x", R.Tensor((), dtype="int32"))
     y = rx.Var("y")
     block = rx.BindingBlock([rx.VarBinding(y, rx.op.print(x, format="{}"))])
-    # print is impure, but the function is not labeled as impure
+    # print is impure, and the function is labeled as impure.
     func = rx.Function(
         [x], rx.SeqExpr([block], x), R.Tensor((), dtype="int32"), is_pure=False
     ).with_attrs({"global_symbol": "foo"})
+    assert not func.is_pure
 
     mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
     assert rx.analysis.well_formed(mod)
@@ -589,9 +625,13 @@ def test_force_pure():
     x = rx.Var("x", R.Tensor((), dtype="int32"))
     y = rx.Var("y")
     block = rx.BindingBlock([rx.VarBinding(y, rx.op.print(x, format="{}"))])
     # print is impure, but force_pure overrides the judgment
-    func = rx.Function([x], rx.SeqExpr([block], x), R.Tensor((), dtype="int32")).with_attrs(
-        {"global_symbol": "foo", "relax.force_pure": True}
+    func = rx.Function(
+        [x],
+        rx.SeqExpr([block], x),
+        R.Tensor((), dtype="int32"),
+        attrs={"global_symbol": "foo", "relax.force_pure": True},
     )
+    assert func.is_pure
     mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
     assert rx.analysis.well_formed(mod)
diff --git a/tests/python/relax/test_frontend_nn_debug.py b/tests/python/relax/test_frontend_nn_debug.py
index a055631a4d51..eca6c6f3c082 100644
--- a/tests/python/relax/test_frontend_nn_debug.py
+++ b/tests/python/relax/test_frontend_nn_debug.py
@@ -24,6 +24,8 @@
 from tvm.relax.frontend.nn import op, spec
 from tvm.runtime import NDArray
 
+from tvm.script import ir as I, relax as R
+
 
 def test_debug_print():
     class Layer(nn.Module):
@@ -42,6 +44,62 @@ def forward(self, x: nn.Tensor):  # pylint: disable=invalid-name
     assert isinstance(y, torch.Tensor)
 
 
+def test_debug_print_well_formed():
+    class Layer(nn.Module):
+        def forward(self, state: nn.Tensor):
+            state = state * 2.0
+            op.print_(state)
+            state = state * 2.0
+            return state
+
+    forward_code = Layer.forward.__wrapped__.__code__
+    debug_location = f"{forward_code.co_filename}:{forward_code.co_firstlineno+2}"
+
+    model, _ = Layer().export_tvm(
+        spec={
+            "forward": {"state": spec.Tensor([10, 5], dtype="float32")},
+        },
+        debug=True,
+    )
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def _initialize_effect() -> R.Tuple(R.Object):
+            with R.dataflow():
+                _io = R.null_value()
+                gv = (_io,)
+                R.output(gv)
+            return gv
+
+        @R.function(pure=False)
+        def forward(
+            state: R.Tensor((10, 5), dtype="float32"), _io: R.Object
+        ) -> R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tuple(R.Object)):
+            R.func_attr({"num_input": 2})
+            with R.dataflow():
+                mul = R.multiply(state, R.const(2, "float32"))
+                R.output(mul)
+
+            _io1 = R.call_packed(
+                "vm.builtin.invoke_debug_func",
+                _io,
+                R.str("vm.builtin.debug_print"),
+                R.str(debug_location),
+                mul,
+                sinfo_args=(R.Object,),
+            )
+
+            with R.dataflow():
+                mul1 = R.multiply(mul, R.const(2, "float32"))
+                gv1 = mul1, (_io1,)
+                R.output(gv1)
+
+            return gv1
+
+    tvm.ir.assert_structural_equal(Expected, model)
+
+
 def test_debug_func():
     @tvm.register_func("testing.relax.frontend.nn.test_debug_func")
     def _debug(  # pylint: disable=too-many-arguments
@@ -79,5 +137,4 @@ def forward(self, x: nn.Tensor, v: tir.Var):  # pylint: disable=invalid-name
 
 
 if __name__ == "__main__":
-    test_debug_print()
-    test_debug_func()
+    tvm.testing.main()
diff --git a/tests/python/relax/test_frontend_nn_extern_module.py b/tests/python/relax/test_frontend_nn_extern_module.py
index 6eaf1fbfc805..fc6d491db3f2 100644
--- a/tests/python/relax/test_frontend_nn_extern_module.py
+++ b/tests/python/relax/test_frontend_nn_extern_module.py
@@ -91,10 +91,9 @@ def scalar_add(
         ) -> R.Tensor((), dtype="float32"):
             R.func_attr({"num_input": 2})
             with R.dataflow():
-                ext_scalar_add = R.call_dps_packed(
+                gv = R.call_dps_packed(
                     "ext_scalar_add", (a, b), out_sinfo=R.Tensor((), dtype="float32")
                 )
-                gv: R.Tensor((), dtype="float32") = ext_scalar_add
                 R.output(gv)
             return gv
@@ -107,10 +106,9 @@ def test_sym(
             z = T.int64()
             R.func_attr({"num_input": 2})
             with R.dataflow():
-                ext_test_sym = R.call_dps_packed(
+                gv1 = R.call_dps_packed(
                     "ext_test_sym", (a, b), out_sinfo=R.Tensor((x, y, z, 9), dtype="float32")
                 )
-                gv1: R.Tensor((x, y, z, 9), dtype="float32") = ext_test_sym
                 R.output(gv1)
             return gv1
diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index 9b357114d351..743de3089548 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -491,8 +491,7 @@ def _initialize_effect() -> R.Tuple(R.Object, R.Object):
                     R.prim_value(0),
                     sinfo_args=(R.Object,),
                 )
-                lv1 = _io, cache
-                gv = lv1
+                gv = _io, cache
                 R.output(gv)
             return gv
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index eb1df67a8f81..239ef3e5610b 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -532,8 +532,7 @@ def add_one(A: T.Buffer((T.int64(10), T.int64(10)), "float32"), T_add: T.Buffer(
         def _initialize_effect() -> R.Tuple(R.Object):
             with R.dataflow():
                 _io: R.Object = R.null_value()
-                lv: R.Tuple(R.Object) = (_io,)
-                gv: R.Tuple(R.Object) = lv
+                gv = (_io,)
                 R.output(gv)
             return gv
@@ -605,8 +604,7 @@ def llama_fused_rope(var_qkv: T.handle, offset: T.int64, var_q: T.handle, var_k:
         def _initialize_effect() -> R.Tuple(R.Object):
             with R.dataflow():
                 _io: R.Object = R.null_value()
-                lv: R.Tuple(R.Object) = (_io,)
-                gv: R.Tuple(R.Object) = lv
+                gv = (_io,)
                 R.output(gv)
             return gv
@@ -693,8 +691,7 @@ def inplace_take(
         def _initialize_effect() -> R.Tuple(R.Object):
             with R.dataflow():
                 _io: R.Object = R.null_value()
-                lv: R.Tuple(R.Object) = (_io,)
-                gv: R.Tuple(R.Object) = lv
+                gv = (_io,)
                 R.output(gv)
             return gv
@@ -711,13 +708,12 @@ def test(
             R.func_attr({"num_input": 4})
             cls = Expected
             with R.dataflow():
-                lv1 = R.call_tir(
+                gv1 = R.call_tir(
                     cls.inplace_take,
                     (embedding_table, input_ids, embedding_dst),
                     out_sinfo=R.Tensor((total_seq_len, hidden_size), dtype),
                     tir_vars=R.shape([offset_1]),
                 )
-                gv1: R.Tensor((total_seq_len, hidden_size), dtype) = lv1
                 R.output(gv1)
             return gv1
@@ -766,8 +762,7 @@ def test(A: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16, 16), dtype="fl
             R.func_attr({"num_input": 1})
             cls = Expected
             with R.dataflow():
-                lv = R.call_tir(cls.tir_func, (A,), out_sinfo=R.Tensor((16, 16), dtype="float32"))
-                gv: R.Tensor((16, 16), dtype="float32") = lv
+                gv = R.call_tir(cls.tir_func, (A,), out_sinfo=R.Tensor((16, 16), dtype="float32"))
                 R.output(gv)
             return gv
@@ -794,8 +789,7 @@ class Expected:
         def _initialize_effect() -> R.Tuple(R.Object):
             with R.dataflow():
                 _io: R.Object = R.null_value()
-                lv: R.Tuple(R.Object) = (_io,)
-                gv: R.Tuple(R.Object) = lv
+                gv = (_io,)
                 R.output(gv)
             return gv
@@ -845,7 +839,6 @@ def test(self):
 
 
 @tvm.testing.requires_gpu
 def test_multinomial_from_uniform():
-
     prob_shape = (3, 5)
     sample_shape = (6, 1)
@@ -882,8 +875,7 @@ def get_sample_index(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
         def _initialize_effect() -> R.Tuple(R.Object):
             with R.dataflow():
                 _io: R.Object = R.null_value()
-                lv: R.Tuple(R.Object) = (_io,)
-                gv: R.Tuple(R.Object) = lv
+                gv = (_io,)
                 R.output(gv)
             return gv
@@ -1009,8 +1001,7 @@ def get_renorm_prob(A: T.handle, B: T.handle, C: T.handle, D: T.handle):
         def _initialize_effect() -> R.Tuple(R.Object):
             with R.dataflow():
                 _io: R.Object = R.null_value()
-                lv: R.Tuple(R.Object) = (_io,)
-                gv: R.Tuple(R.Object) = lv
+                gv = (_io,)
                 R.output(gv)
             return gv
@@ -1124,8 +1115,7 @@ def get_renorm_cutoff(A: T.handle, B: T.handle, C: T.handle, D: T.handle, E: T.h
         def _initialize_effect() -> R.Tuple(R.Object):
             with R.dataflow():
                 _io: R.Object = R.null_value()
-                lv: R.Tuple(R.Object) = (_io,)
-                gv: R.Tuple(R.Object) = lv
+                gv = (_io,)
                 R.output(gv)
             return gv
diff --git a/tests/python/relax/test_frontend_nn_packing.py b/tests/python/relax/test_frontend_nn_packing.py
index 56b614a807b8..c2cc22c17d40 100644
--- a/tests/python/relax/test_frontend_nn_packing.py
+++ b/tests/python/relax/test_frontend_nn_packing.py
@@ -59,8 +59,7 @@ def forward(
             matmul = R.matmul(x, matmul_1_weight)
             matmul_2_weight = R.permute_dims(linear_2_weight)
             matmul1 = R.matmul(x, matmul_2_weight)
-            add = R.add(matmul, matmul1)
-            gv = add
+            gv = R.add(matmul, matmul1)
             R.output(gv)
         return gv
diff --git a/tests/python/relax/test_frontend_nn_subroutines.py b/tests/python/relax/test_frontend_nn_subroutines.py
index 6bbf57aeadde..32ae967916a8 100644
--- a/tests/python/relax/test_frontend_nn_subroutines.py
+++ b/tests/python/relax/test_frontend_nn_subroutines.py
@@ -61,8 +61,7 @@ def forward(
         def _initialize_effect() -> R.Tuple(R.Object):
             with R.dataflow():
                 _io: R.Object = R.null_value()
-                lv: R.Tuple(R.Object) = (_io,)
-                gv: R.Tuple(R.Object) = lv
+                gv = (_io,)
                 R.output(gv)
             return gv
@@ -75,9 +74,8 @@ def layer(
             with R.dataflow():
                 state = R.matmul(state, weights)
                 state = Expected.activation(state)
-                dataflow_output = state
-                R.output(dataflow_output)
-            return dataflow_output
+                R.output(state)
+            return state
 
         @R.function(private=True)
         def activation(
@@ -85,9 +83,8 @@ def activation(
         ) -> R.Tensor(("batch_size", 32), dtype="float32"):
             with R.dataflow():
                 state = R.nn.silu(state)
-                dataflow_output = state
-                R.output(dataflow_output)
-            return dataflow_output
+                R.output(state)
+            return state
 
     mod = Layer(64, 32)
     batch_size = tvm.tir.Var("batch_size", "int64")
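
End to end, the same inference applies to TVMScript: a `@R.function` without a `pure=` annotation forwards `purity=None` through the parser, and the `relax::Function` constructor infers the purity. A small sketch of the expected behavior (module and function names illustrative):

```python
from tvm.script import ir as I, relax as R

@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
        # No `pure=` annotation: purity is inferred from the body,
        # which contains no impure calls.
        y = R.add(x, x)
        return y

assert Module["main"].is_pure
```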