diff --git a/src/relax/transform/lazy_transform_params.cc b/src/relax/transform/lazy_transform_params.cc index fb401e1b6787..f55b93ff3d3a 100644 --- a/src/relax/transform/lazy_transform_params.cc +++ b/src/relax/transform/lazy_transform_params.cc @@ -149,7 +149,7 @@ class LazyOutputMutator : public ExprMutator { Var fset_output("fset_output", FuncStructInfo({PrimStructInfo(DataType::Int(64)), ObjectStructInfo()}, - TupleStructInfo(Array{}))); + TupleStructInfo(Array{}), /* purity = */ false)); plan_ = FunctionPlan{std::move(output_lookup), fset_output}; std::optional num_input_params = GetNumInputParams(func); @@ -189,6 +189,7 @@ class LazyOutputMutator : public ExprMutator { auto write_ptr = node.CopyOnWrite(); write_ptr->params = new_params; write_ptr->body = new_body; + write_ptr->is_pure = false; } if (num_input_params.has_value()) { node = WithAttr(node, attr::kNumInput, Integer(num_input_params.value() + 1)); diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py index 040aea28909d..278ac825f7a7 100644 --- a/tests/python/relax/test_transform_lazy_transform_params.py +++ b/tests/python/relax/test_transform_lazy_transform_params.py @@ -1002,11 +1002,11 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl @I.ir_module class Expected: - @R.function + @R.function(pure=False) def transform_params( A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32"), - fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])), + fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False), ): C = R.multiply(A, R.const(2, "float32")) fset_output(R.prim_value(1), C) @@ -1036,11 +1036,11 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl @I.ir_module class Expected: - @R.function + @R.function(pure=False) def transform_params( A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32"), - fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])), + fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False), ): fset_output(R.prim_value(1), B) C = R.multiply(A, R.const(2, "float32")) @@ -1070,10 +1070,10 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl @I.ir_module class Expected: - @R.function + @R.function(pure=False) def transform_params( A: R.Tensor([16, 16], "float32"), - fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])), + fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False), B: R.Tensor([16, 16], "float32"), ): R.func_attr({"num_input": 2}) @@ -1105,11 +1105,11 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl @I.ir_module class Expected: - @R.function + @R.function(pure=False) def transform_params( A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32"), - fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])), + fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False), ): C = R.multiply(A, R.const(2, "float32")) D = R.add(C, B) @@ -1140,11 +1140,11 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl @I.ir_module class Expected: - @R.function + @R.function(pure=False) def transform_params( A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32"), - fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])), + fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False), ): C = R.multiply(A, R.const(2, "float32")) fset_output(R.prim_value(0), C) @@ -1171,11 +1171,11 @@ def transform_params(A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "fl @I.ir_module class Expected: - @R.function + @R.function(pure=False) def transform_params( A: R.Tensor([16, 16], "float32"), B: R.Tensor([16, 16], "float32"), - fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([])), + fset_output: R.Callable([R.Prim("int64"), R.Object], R.Tuple([]), purity=False), ): C = R.multiply(A, R.const(2, "float32")) D = R.add(C, B)