From 16004415a6d6c7ac6e643d3a506aee8b66f616d9 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 21 Jun 2023 18:12:11 -0400 Subject: [PATCH 01/26] Use privacy annotation to decide whether to include global_symbol --- python/tvm/script/parser/relax/entry.py | 6 +- python/tvm/script/parser/relax/parser.py | 29 +++--- src/relax/transform/gradient.cc | 4 +- src/relax/transform/lift_transform_params.cc | 5 +- src/script/ir_builder/relax/frame.cc | 18 ++++ src/script/printer/relax/function.cc | 40 ++++++-- .../test_transform_attach_global_symbol.py | 6 +- .../test_transform_combine_parallel_matmul.py | 3 + .../test_transform_dead_code_elimination.py | 8 +- tests/python/relax/test_transform_fuse_ops.py | 14 +-- .../test_transform_fuse_ops_by_pattern.py | 29 +++--- tests/python/relax/test_transform_fuse_tir.py | 4 +- tests/python/relax/test_transform_gradient.py | 34 +++---- .../relax/test_transform_lambda_lift.py | 14 +-- ...est_transform_merge_composite_functions.py | 47 ++++----- .../test_transform_rewrite_cuda_graph.py | 8 +- tests/python/relax/test_tvmscript_parser.py | 29 +++++- .../relax/test_tvmscript_printer_relax.py | 99 +++++++++++++++++++ 18 files changed, 290 insertions(+), 107 deletions(-) diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index 2711e855dddf..ff237a5600e7 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -45,9 +45,11 @@ # this formulation allows us to support having @R.function # appear as a decorator by itself or to have optional arguments # like @R.function(pure=False) -def function(f: Optional[FType] = None, pure: bool = True) -> Union[Function, FType]: +def function( + f: Optional[FType] = None, pure: bool = True, private: bool = False +) -> Union[Function, FType]: # pylint: disable=unused-argument - # (pure isn't used here, but is used later in parsing) + # (pure and private aren't used here, but are used later in parsing) # need to inspect the stack first because is_defined_in_class expects the outer class # to be in a particular position in the stack diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index 427c56bcc8e3..d69a841d4256 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -21,7 +21,7 @@ from typing import Any, Dict, Optional from tvm import relax, tir -from tvm.ir import GlobalVar, structural_equal +from tvm.ir import GlobalVar, make_node, structural_equal from tvm.relax import Expr, StructInfo from tvm.relax.utils import convert_to_expr from tvm.script.ir_builder.relax.frame import BlockFrame @@ -178,7 +178,8 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: local_func_var = relax.Var(node.name, relax.FuncStructInfo(params_sinfo, ret_sinfo)) self.var_table.add(node.name, local_func_var) - purity = find_purity_annotation(node) + purity = find_decorator_annotation(node, "pure") + # don't handle the privacy annotation here because it's only relevant for global funcs with self.var_table.with_frame(): with self.with_dispatch_token("relax"): @@ -204,20 +205,19 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: self.visit_body(node.body) -def find_purity_annotation(node: doc.FunctionDef) -> bool: +def find_decorator_annotation(node: doc.FunctionDef, annotation: str, default: bool = True) -> bool: """ - Check the value of `pure` in the function decorator. 
- Returns the annotated purity if present, otherwise defaulting to True. - This allows for specifying the purity in the function signature. + Check the value of given annotation (argument name) in the function decorator. + Returns the value of the annotation if present, otherwise giving the default value. """ - # look for the pure argument in the function decorator + # look for the named argument in the function decorator for dec in node.decorator_list: if not isinstance(dec, doc.Call) or dec.func.attr != "function": continue for keyword in dec.keywords: - if keyword.arg == "pure": + if keyword.arg == annotation: return keyword.value.value - return True + return default @dispatch.register(token="relax", type_name="tvm_declare_function") @@ -238,8 +238,15 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar param_sinfo = eval_struct_info(self, arg.annotation, eval_str=True) params.append(relax.Var(arg.arg, param_sinfo)) - is_pure = find_purity_annotation(node) - func_signature = relax.Function.create_empty(params, ret_sinfo, is_pure=is_pure) + is_pure = find_decorator_annotation(node, "pure") + + # if the global function is not private, then use its name as the global symbol + is_private = find_decorator_annotation(node, "private", default=False) + attrs = None + if not is_private: + attrs = make_node("DictAttrs", global_symbol=node.name) + + func_signature = relax.Function.create_empty(params, ret_sinfo, is_pure=is_pure, attrs=attrs) return I.decl_function(node.name, func_signature) diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc index 7645ae8cb6c6..65fb3b96d6c3 100644 --- a/src/relax/transform/gradient.cc +++ b/src/relax/transform/gradient.cc @@ -333,7 +333,9 @@ class GradientMutator : private ExprMutator { } GradientMutator mutator(mod, require_grads_value, target_index); - Function new_func_transformed = Downcast(mutator.VisitExpr(new_func)); + // remove the global symbol if the original had one (the adjoint does not need a global symbol) + Function new_func_transformed = + WithoutAttr(Downcast(mutator.VisitExpr(new_func)), tvm::attr::kGlobalSymbol); IRModule new_module = GetRef(mod.CopyOnWrite()); new_module->Add(GlobalVar(func_name + "_adjoint"), new_func_transformed); diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index f7c9a4189dbb..fb1f2927769d 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -251,7 +251,10 @@ class TransformParamsLifter : public ExprMutator { lift_plan_ = planner.Plan(func, num_input); // Step 2: Add the lifted function to the module - builder_->AddFunction(lift_plan_.f_transform_params, new_func_name); + // (The lifted function should be public so we add a global symbol to it) + auto lift_func = + WithAttr(lift_plan_.f_transform_params, tvm::attr::kGlobalSymbol, new_func_name); + builder_->AddFunction(lift_func, new_func_name); // Step 3: Update the current function. 
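A minimal sketch of the behavior these parser changes introduce, modeled on the parser and printer tests later in this patch (the class and function names here are illustrative only): a module-level @R.function is public by default and implicitly carries a global_symbol attribute equal to its own name, while @R.function(private=True) leaves the global_symbol attribute off entirely.

import tvm
from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class Example:
    # Public by default: printed without R.func_attr({"global_symbol": ...}),
    # but the attribute is attached to the parsed function under the hood.
    @R.function
    def main(x: R.Tensor((2, 2), "float32")) -> R.Tensor((2, 2), "float32"):
        y = R.add(x, x)
        return y

    # Private: no global_symbol is attached, and the printer round-trips this
    # by emitting the private=True argument on the decorator.
    @R.function(private=True)
    def helper(x: R.Tensor((2, 2), "float32")) -> R.Tensor((2, 2), "float32"):
        y = R.multiply(x, x)
        return y
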
diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index 00bbd2a551a6..67114bd7101b 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -74,6 +74,24 @@ void FunctionFrameNode::ExitWithScope() { "function scope, if it's defined in a Module"; const IRModuleFrame& frame = opt_frame.value(); const String& func_name = name.value_or(""); + // If the function has already been declared (i.e., it is global), see if there is + // already a global symbol defined for it (i.e., it is not private). + // If yes, add it to the current function's attributes (unless one was manually defined) + if (frame->global_var_map.count(func_name)) { + auto decl = frame->functions.at(frame->global_var_map.at(func_name)); + if (decl->attrs.defined()) { + auto attr_dict = decl->attrs.get()->dict; + if (attr_dict.count("global_symbol") && !attrs.count("global_symbol")) { + Map new_attrs; + for (auto kv : attrs) { + new_attrs.Set(kv.first, kv.second); + } + new_attrs.Set("global_symbol", attr_dict.at("global_symbol")); + auto mut_f = func.CopyOnWrite(); + mut_f->attrs = DictAttrs(new_attrs); + } + } + } if (!frame->global_var_map.count(func_name)) { // First time visiting the function. ir::DeclFunction(func_name, func); diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc index bd5d969563a6..a42685dcf828 100644 --- a/src/script/printer/relax/function.cc +++ b/src/script/printer/relax/function.cc @@ -52,17 +52,45 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) (*f)->func_vars = nullptr; // Step 4. Print attributes if (n->attrs.defined() && !n->attrs->dict.empty()) { - (*f)->stmts.push_back( - ExprStmtDoc(Relax(d, "func_attr") // - ->Call({d->AsDoc(n->attrs, n_p->Attr("attrs"))}))); + // if the function is a global function and has a global symbol, + // then don't print the global symbol (it will be implicit from not being private) + if (d->frames.size() == 3 && n->attrs->dict.count("global_symbol")) { + Map new_attrs; + for (auto kv : n->attrs->dict) { + if (kv.first != "global_symbol") { + new_attrs.Set(kv.first, kv.second); + } + } + if (!new_attrs.empty()) { + (*f)->stmts.push_back(ExprStmtDoc( + Relax(d, "func_attr") // + ->Call({d->AsDoc(DictAttrs(new_attrs), n_p->Attr("attrs"))}))); + } + } else { + (*f)->stmts.push_back( + ExprStmtDoc(Relax(d, "func_attr") // + ->Call({d->AsDoc(n->attrs, n_p->Attr("attrs"))}))); + } } // Step 5. Prepare the decorator (include purity if it's impure) ExprDoc decorator = Relax(d, "function"); + Array pos_args = {}; + Array dec_keys; + Array dec_values; if (!n->is_pure) { - Array pos_args = {}; - decorator = std::move(decorator->Call( - pos_args, {"pure"}, {LiteralDoc::Boolean(false, Optional())})); + dec_keys.push_back("pure"); + dec_values.push_back(LiteralDoc::Boolean(false, Optional())); + } + // if the function is global and does not have a global symbol, indicate that it's private + if (d->frames.size() == 3 && + (!n->attrs.defined() || !n->attrs->dict.count("global_symbol"))) { + dec_keys.push_back("private"); + dec_values.push_back(LiteralDoc::Boolean(true, Optional())); } + if (dec_keys.size()) { + decorator = std::move(decorator->Call(pos_args, dec_keys, dec_values)); + } + // Step 6. 
Print body Array body = PrintSeqExpr(Downcast(n->body), n_p->Attr("body"), d, /*use_ret=*/true); diff --git a/tests/python/relax/test_transform_attach_global_symbol.py b/tests/python/relax/test_transform_attach_global_symbol.py index 035e21609daa..680df969474a 100644 --- a/tests/python/relax/test_transform_attach_global_symbol.py +++ b/tests/python/relax/test_transform_attach_global_symbol.py @@ -43,7 +43,7 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: C[vi, vj] = T.float32(0) C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] - @R.function + @R.function(private=True) def main( x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32") ) -> R.Tensor: @@ -74,7 +74,6 @@ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: def main( x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32") ) -> R.Tensor: - R.func_attr({"global_symbol": "main"}) m, n, k = T.int64(), T.int64(), T.int64() gv0 = R.call_tir(Expected.tir_matmul, (x, w), R.Tensor((m, k), dtype="float32")) return gv0 @@ -94,7 +93,7 @@ class Before: def tir_zeros(x: T.Buffer((2), "float32")) -> None: x[0] = T.float32(0) - @R.function + @R.function(private=True) def main() -> R.Tensor: gv0 = R.call_tir(Before.tir_zeros, (), R.Tensor((2,), dtype="float32")) return gv0 @@ -110,7 +109,6 @@ def tir_zeros(x: T.Buffer((2), "float32")) -> None: @R.function def main() -> R.Tensor: - R.func_attr({"global_symbol": "main"}) gv0 = R.call_tir(Expected.tir_zeros, (), R.Tensor((2,), dtype="float32")) return gv0 diff --git a/tests/python/relax/test_transform_combine_parallel_matmul.py b/tests/python/relax/test_transform_combine_parallel_matmul.py index 719daaf4496d..66861f4ed485 100644 --- a/tests/python/relax/test_transform_combine_parallel_matmul.py +++ b/tests/python/relax/test_transform_combine_parallel_matmul.py @@ -378,6 +378,7 @@ def expected1( w2: R.Tensor((2, 640, 640), dtype="float32"), w3: R.Tensor((3, 4, 640, 640), dtype="float32"), ) -> R.Tensor: + R.func_attr({"global_symbol": "main"}) with R.dataflow(): lv = R.concat((w0, w2), axis=2) lv1 = R.matmul(x, lv, out_dtype="float32") @@ -458,6 +459,7 @@ def expected1( b0: R.Tensor((640,), dtype="float32"), b1: R.Tensor((640,), dtype="float32"), ) -> R.Tensor: + R.func_attr({"global_symbol": "main"}) with R.dataflow(): lv = R.concat((w0, w1, w2), axis=1) lv1 = R.matmul(x1, lv, out_dtype="float32") @@ -515,6 +517,7 @@ def expected( w3: R.Tensor((640, 640), dtype="float32"), w4: R.Tensor((640, 640), dtype="float32"), ) -> R.Tensor: + R.func_attr({"global_symbol": "main"}) with R.dataflow(): lv = R.concat((w0, w1, w2), axis=1) lv1 = R.matmul(x1, lv, out_dtype="float32") diff --git a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py index 9c6e0e0567fe..12a3de6acb30 100644 --- a/tests/python/relax/test_transform_dead_code_elimination.py +++ b/tests/python/relax/test_transform_dead_code_elimination.py @@ -168,7 +168,7 @@ def tir_add( vi, vj = T.axis.remap("SS", [i, j]) z[vi, vj] = x[vi, vj] + y[vi, vj] - @R.function + @R.function(private=True) def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")): gv0 = R.add(x, w) return gv0 @@ -202,7 +202,7 @@ def tir_add( vi, vj = T.axis.remap("SS", [i, j]) z[vi, vj] = x[vi, vj] + y[vi, vj] - @R.function + @R.function(private=True) def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")): gv0 = R.add(x, w) return gv0 @@ -239,7 +239,7 @@ def tir_add( vi, vj = T.axis.remap("SS", [i, j]) z[vi, 
vj] = x[vi, vj] + y[vi, vj] - @R.function + @R.function(private=True) def unused_func(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")): gv0 = R.add(x, w) return gv0 @@ -310,7 +310,7 @@ def unused_func1( vi, vj = T.axis.remap("SS", [i, j]) z[vi, vj] = x[vi, vj] + y[vi, vj] - @R.function + @R.function(private=True) def unused_func2(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")): gv0 = R.add(x, w) return gv0 diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py index 169539b07243..6a30aabf3570 100644 --- a/tests/python/relax/test_transform_fuse_ops.py +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -938,7 +938,7 @@ def relu(A: T.Buffer((T.int64(1), T.int64(512), T.int64(64), T.int64(64)), "floa T.writes(B[v_i0, v_i1, v_i2, v_i3]) B[v_i0, v_i1, v_i2, v_i3] = T.max(A[v_i0, v_i1, v_i2, v_i3], T.float32(0)) - @R.function + @R.function(private=True) def fused_layer_norm_relu(x: R.Tensor((1, 512, 64, 64), dtype="float32"), mean: R.Tensor((64, 64), dtype="float32"), var: R.Tensor((64, 64), dtype="float32")) -> R.Tensor((1, 512, 64, 64), dtype="float32"): R.func_attr({"Primitive": 1}) cls = Expected @@ -1080,7 +1080,7 @@ def transpose(rxplaceholder: T.Buffer((T.int64(320), T.int64(1280)), "float32"), T.writes(T_transpose[v_ax0, v_ax1]) T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0] - @R.function + @R.function(private=True) def fused_conv2d_add_add2(inp_0: R.Tensor((2, 320, 64, 64), dtype="float32"), w1: R.Tensor((320, 320, 3, 3), dtype="float32"), lv28: R.Tensor((1, 320, 1, 1), dtype="float32"), lv35: R.Tensor((2, 320, 1, 1), dtype="float32")) -> R.Tensor((2, 320, 64, 64), dtype="float32"): R.func_attr({"Primitive": 1}) cls = Expected @@ -1091,7 +1091,7 @@ def fused_conv2d_add_add2(inp_0: R.Tensor((2, 320, 64, 64), dtype="float32"), w1 R.output(gv) return gv - @R.function + @R.function(private=True) def fused_matmul_add1(inp_1: R.Tensor((2, 1280), dtype="float32"), lv31: R.Tensor((1280, 320), dtype="float32"), b2: R.Tensor((320,), dtype="float32")) -> R.Tensor((2, 320), dtype="float32"): cls = Expected R.func_attr({"Primitive": 1}) @@ -1226,7 +1226,7 @@ def transpose1(rxplaceholder: T.Buffer((T.int64(10), T.int64(128)), "float32"), T.writes(T_transpose[v_ax0, v_ax1]) T_transpose[v_ax0, v_ax1] = rxplaceholder[v_ax1, v_ax0] - @R.function + @R.function(private=True) def fused_matmul1_add1(inp_1: R.Tensor((1, 128), dtype="float32"), lv4: R.Tensor((128, 10), dtype="float32"), linear2_bias: R.Tensor((10,), dtype="float32")) -> R.Tensor((1, 10), dtype="float32"): R.func_attr({"Primitive": 1}) cls = Expected @@ -1268,7 +1268,7 @@ def main(x: R.Tensor(["n", "m"], "float32")): @I.ir_module class Expected: - @R.function + @R.function(private=True) def fused_add_exp_squeeze( x: R.Tensor(["n", "m"], "float32"), p0: R.Tensor([], "float32") ) -> R.Tensor(["n", "m"], dtype="float32"): @@ -1306,7 +1306,7 @@ def main(s: R.Shape(["n"])): @I.ir_module class Expected: - @R.function + @R.function(private=True) def fused_full_trilu_broadcast_to( s: R.Shape(["n"]), ) -> R.Tensor([1, 1, "n", "n"], "float32"): @@ -1354,7 +1354,7 @@ def main(s: R.Shape(["n"]), kv_cache: R.Object): @I.ir_module class Expected: - @R.function + @R.function(private=True) def fused_full_trilu_broadcast_to( s: R.Shape(["n"]), ) -> R.Tensor([1, 1, "n", "n"], "float32"): diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py index 5fb2b3332c23..592132516bee 100644 --- 
a/tests/python/relax/test_transform_fuse_ops_by_pattern.py +++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py @@ -60,7 +60,7 @@ def main( R.output(gv) return gv - @R.function + @R.function(private=True) def fused_relax_nn_conv2d_relax_nn_relu_dnnl( data1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -124,7 +124,7 @@ def main( R.output(gv) return gv - @R.function + @R.function(private=True) def fused_relax_nn_conv2d_relax_nn_relu( data1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -138,7 +138,7 @@ def fused_relax_nn_conv2d_relax_nn_relu( R.output(gv1) return gv1 - @R.function + @R.function(private=True) def fused_relax_nn_conv2d_relax_nn_relu1( conv1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight21: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -174,7 +174,7 @@ def main( R.output(conv2d) return conv2d - @R.function + @R.function(private=True) def fused_relax_nn_conv2d( data1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -187,7 +187,7 @@ def fused_relax_nn_conv2d( R.output(gv) return gv - @R.function + @R.function(private=True) def fused_relax_nn_conv2d1( conv11: R.Tensor((1, 64, 56, 56), dtype="float32"), weight21: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -236,7 +236,7 @@ def main( R.output(gv) return gv - @R.function + @R.function(private=True) def fused_relax_nn_conv2d_relax_nn_relu( conv1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight21: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -250,7 +250,7 @@ def fused_relax_nn_conv2d_relax_nn_relu( R.output(gv1) return gv1 - @R.function + @R.function(private=True) def fused_relax_nn_conv2d( data1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -303,7 +303,7 @@ def main( R.output(out) return out - @R.function + @R.function(private=True) def fused_relax_nn_conv2d_relax_nn_relu( data1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -377,7 +377,7 @@ def main( @tvm.script.ir_module class Conv2dx2_partitioned: - @R.function + @R.function(private=True) def fused_relax_nn_conv2d_cutlass( data: R.Tensor((16, 32, 32, 16), dtype="float16"), weight1: R.Tensor((16, 3, 3, 16), dtype="float16"), @@ -565,7 +565,7 @@ def relu( T.writes(out[i, j, k, l]) out[i, j, k, l] = T.max(data[i, j, k, l], T.float32(0)) - @R.function + @R.function(private=True) def fused_relax_nn_conv2d( data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -617,7 +617,7 @@ def main( @I.ir_module class Conv2dReLU_partitioned: - @R.function + @R.function(private=True) def fused_relax_nn_conv2d( data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -685,7 +685,7 @@ def main( @I.ir_module class Conv2dWithConstantWeight_partitioned: - @R.function + @R.function(private=True) def fused_relax_nn_conv2d( data: R.Tensor((1, 64, 56, 56), dtype="float32"), param_0: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -721,6 +721,7 @@ def main( def test_split(): @R.function def func(inp: R.Tensor((16, 32), "float32")): + R.func_attr({"global_symbol": "main"}) with R.dataflow(): tup = R.split(inp, [16], axis=1) out = R.add(tup[0], tup[1]) @@ -729,7 +730,7 @@ def func(inp: R.Tensor((16, 32), "float32")): @tvm.script.ir_module class Expected1: - @R.function + @R.function(private=True) def fused_relax_split( inp: 
R.Tensor((16, 32), dtype="float32") ) -> R.Tuple(R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16), dtype="float32")): @@ -756,7 +757,7 @@ def main(inp: R.Tensor((16, 32), dtype="float32")) -> R.Tensor((16, 16), dtype=" @I.ir_module class Expected2: - @R.function + @R.function(private=True) def fused_relax_split_relax_add( inp: R.Tensor((16, 32), dtype="float32") ) -> R.Tensor((16, 16), dtype="float32"): diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index 00dc7146541b..1aca1a1b583f 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -603,7 +603,7 @@ def before(): @I.ir_module class Expected: - @R.function + @R.function(private=True) def func1(x: R.Tensor((10, 20), dtype="float32")) -> R.Tensor((10, 20), dtype="float32"): with R.dataflow(): gv2 = R.call_tir( @@ -614,7 +614,7 @@ def func1(x: R.Tensor((10, 20), dtype="float32")) -> R.Tensor((10, 20), dtype="f R.output(gv2) return gv2 - @R.function + @R.function(private=True) def func2(x: R.Tensor((20, 10), dtype="float32")) -> R.Tensor((20, 10), dtype="float32"): with R.dataflow(): gv3 = R.call_tir( diff --git a/tests/python/relax/test_transform_gradient.py b/tests/python/relax/test_transform_gradient.py index 50063fe385bb..d84d76291959 100644 --- a/tests/python/relax/test_transform_gradient.py +++ b/tests/python/relax/test_transform_gradient.py @@ -44,7 +44,7 @@ def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor(None, "float32", ndim=0): R.output(gv) return gv - @R.function + @R.function(private=True) def main_adjoint(x: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor(None, "float32", ndim=0),R.Tuple(R.Tensor(None, "float32", ndim=2)),): with R.dataflow(): gv: R.Tensor((), "float32") = R.sum(x, axis=None, keepdims=False) @@ -83,7 +83,7 @@ def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"): R.output(gv) return gv - @R.function + @R.function(private=True) def main_adjoint(x: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))): with R.dataflow(): lv1: R.Tensor((3, 3), "float32") = x @@ -125,7 +125,7 @@ def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"): R.output(gv) return gv - @R.function + @R.function(private=True) def main_adjoint(x: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))): with R.dataflow(): lv1: R.Tensor((3, 3), "float32") = R.add(x, x) @@ -168,7 +168,7 @@ def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"): R.output(gv) return gv - @R.function + @R.function(private=True) def main_adjoint(x: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))): with R.dataflow(): lv1: R.Tensor((3, 3), "float32") = R.add(x, x) @@ -217,7 +217,7 @@ def main( R.output(gv) return gv - @R.function + @R.function(private=True) def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32"), z: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))): with R.dataflow(): lv1: R.Tensor((3, 3), "float32") = R.add(x, y) @@ -249,7 +249,7 @@ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32"), z: R.Te R.output(gv) return gv - @R.function + @R.function(private=True) def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32"), z: R.Tensor((3, 3), "float32")) -> 
R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))): # block 0 with R.dataflow(): @@ -292,7 +292,7 @@ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Tu R.output(lv1, lv2, lv3) return (lv1, lv2, lv3) - @R.function + @R.function(private=True) def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((), "float32"), R.Tensor((), "float32")), R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))): with R.dataflow(): lv1: R.Tensor((3, 3), "float32") = x @@ -341,7 +341,7 @@ def main(x: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")), y R.output(gv) return gv - @R.function + @R.function(private=True) def main_adjoint(x: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")), y: R.Tensor((3, 3), "float32"), z: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))): with R.dataflow(): lv1: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = (y, z) @@ -399,7 +399,7 @@ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Te R.output(gv) return gv - @R.function + @R.function(private=True) def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))): # block 0 with R.dataflow(): @@ -473,7 +473,7 @@ def main(x: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float R.output(gv) return gv - @R.function + @R.function(private=True) def main_adjoint(x: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32")), y: R.Tensor((3, 3), "float32"), z: R.Tensor((3, 3), "float32"), u: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))): with R.dataflow(): lv1: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32")) = ((y, z), u) @@ -552,7 +552,7 @@ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Te R.output(gv) return gv - @R.function + @R.function(private=True) def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))): with R.dataflow(): lv0: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = (x, y) @@ -621,7 +621,7 @@ def main(x: R.Tensor((6,), "float32")) -> R.Tensor((), "float32"): R.output(gv) return gv - @R.function + @R.function(private=True) def main_adjoint(x: R.Tensor((6,), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((6,), "float32"))): with R.dataflow(): lv1: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.split(x, indices_or_sections=2, axis=0) @@ -671,7 +671,7 @@ def main(x: R.Tensor((3,), "float32"), y: R.Tuple(R.Tensor((3,), "float32"), R.T R.output(gv) return gv - @R.function + @R.function(private=True) def main_adjoint(x: R.Tensor((3,), "float32"), y: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32"))) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3,), "float32"), R.Tuple(R.Tensor((3,), 
"float32"), R.Tensor((3,), "float32")))): with R.dataflow(): lv1: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = (x, x) @@ -740,7 +740,7 @@ def main(x: R.Tensor((3,), "float32")) -> R.Tensor((), "float32"): R.output(gv) return gv - @R.function + @R.function(private=True) def main_adjoint(x: R.Tensor((3,), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3,), "float32"))): # block 0 with R.dataflow(): @@ -806,7 +806,7 @@ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Te R.output(gv) return gv - @R.function + @R.function(private=True) def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))): with R.dataflow(): lv1: R.Tensor((3, 3), "float32") = R.add(x, cst) @@ -1070,7 +1070,7 @@ def main(x: R.Tensor((3, 4), "float32")): @I.ir_module class Expected: - @R.function + @R.function(private=True) def main_adjoint(x: R.Tensor((3, 4), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 4), dtype="float32"))): with R.dataflow(): s: R.Shape([3, 2, 2]) = R.shape([3, 2, 2]) @@ -1122,7 +1122,7 @@ def main( @I.ir_module class Expected: - @R.function + @R.function(private=True) def main_adjoint(x: R.Tensor((3, 10), dtype="float32"), w0: R.Tensor((10, 5), dtype="float32"), b0: R.Tensor((5,), dtype="float32"), label: R.Tensor((3, 5), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((5,), dtype="float32"))): with R.dataflow(): lv0: R.Tensor((3, 5), dtype="float32") = R.matmul(x, w0, out_dtype="void") diff --git a/tests/python/relax/test_transform_lambda_lift.py b/tests/python/relax/test_transform_lambda_lift.py index ddc274fee272..d67248417173 100644 --- a/tests/python/relax/test_transform_lambda_lift.py +++ b/tests/python/relax/test_transform_lambda_lift.py @@ -42,7 +42,7 @@ def test_basic(): # the target IRModule @tvm.script.ir_module class Expected: - @R.function + @R.function(private=True) def lifted_func_0( x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") ) -> R.Tensor((10, 5), "float32"): @@ -97,12 +97,12 @@ def main( ) return res - @R.function + @R.function(private=True) def lifted_func_1(x1: R.Tensor((2, 3), "float32"), c1: R.Tensor((2, 3), "float32")): r_1: R.Tensor((2, 3), "float32") = R.add(x1, c1) return r_1 - @R.function + @R.function(private=True) def lifted_func_0(y: R.Tensor((2, 3), "float32")) -> R.Object: inner_func = R.make_closure(Expected.lifted_func_1, (y,)) return inner_func @@ -140,7 +140,7 @@ def test_recursive(): # the expected IRModule @tvm.script.ir_module class Expected: - @R.function + @R.function(private=True) def lifted_func_0( i: R.Tensor((), "int32"), s: R.Tensor((2, 3), "float32"), x: R.Tensor((2, 3), "float32") ) -> R.Tensor((2, 3), "float32"): @@ -224,14 +224,14 @@ def glob_func_2( gv11: R.Tensor((10, 5), "float32") = inner(x11, y11) return gv11 - @R.function + @R.function(private=True) def lifted_func_0( x2: R.Tensor((10, 5), "float32"), y2: R.Tensor((10, 5), "float32") ) -> R.Tensor((10, 5), "float32"): s: R.Tensor((10, 5), "float32") = R.add(x2, y2) return s - @R.function + @R.function(private=True) def lifted_func_1( x21: R.Tensor((10, 5), "float32"), y21: R.Tensor((10, 5), "float32") ) -> R.Tensor((10, 5), "float32"): @@ -308,7 +308,7 @@ def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor(dtype="float32", ndim= def test_impure_function(): @tvm.script.ir_module 
class Expected: - @R.function(pure=False) + @R.function(pure=False, private=True) def lifted_func_0() -> R.Tuple: y = R.print(format="Wow!") return y diff --git a/tests/python/relax/test_transform_merge_composite_functions.py b/tests/python/relax/test_transform_merge_composite_functions.py index 61df388c7888..f294b62bc7c8 100644 --- a/tests/python/relax/test_transform_merge_composite_functions.py +++ b/tests/python/relax/test_transform_merge_composite_functions.py @@ -41,7 +41,7 @@ def main( R.output(gv) return gv - @R.function + @R.function(private=True) def fused_relax_nn_conv2d_relax_nn_relu( data1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight11: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -57,7 +57,7 @@ def fused_relax_nn_conv2d_relax_nn_relu( R.output(gv1) return gv1 - @R.function + @R.function(private=True) def fused_relax_nn_conv2d_relax_nn_relu1( conv1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight21: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -164,7 +164,7 @@ def main( R.output(gv2) return gv2 - @R.function + @R.function(private=True) def fused_relax_nn_gelu( lv: R.Tensor((1, 64, 54, 54), dtype="float32") ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): @@ -174,7 +174,7 @@ def fused_relax_nn_gelu( R.output(gv) return gv - @R.function + @R.function(private=True) def fused_relax_nn_relu( lv1: R.Tensor((1, 64, 54, 54), dtype="float32") ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): @@ -184,7 +184,7 @@ def fused_relax_nn_relu( R.output(gv1) return gv1 - @R.function + @R.function(private=True) def fused_relax_add( lv5: R.Tensor((1, 64, 54, 54), dtype="float32"), gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"), @@ -195,7 +195,7 @@ def fused_relax_add( R.output(gv3) return gv3 - @R.function + @R.function(private=True) def fused_relax_nn_conv2d( data1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -332,7 +332,7 @@ def main( R.output(gv2) return gv2 - @R.function + @R.function(private=True) def fused_relax_nn_gelu( lv: R.Tensor((1, 64, 54, 54), dtype="float32") ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): @@ -342,7 +342,7 @@ def fused_relax_nn_gelu( R.output(gv) return gv - @R.function + @R.function(private=True) def fused_relax_nn_relu( lv1: R.Tensor((1, 64, 54, 54), dtype="float32") ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): @@ -352,7 +352,7 @@ def fused_relax_nn_relu( R.output(gv1) return gv1 - @R.function + @R.function(private=True) def fused_relax_add( lv5: R.Tensor((1, 64, 54, 54), dtype="float32"), gelu1: R.Tensor((1, 64, 54, 54), dtype="float32"), @@ -363,7 +363,7 @@ def fused_relax_add( R.output(gv3) return gv3 - @R.function + @R.function(private=True) def fused_relax_nn_conv2d( data1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -509,7 +509,7 @@ def main( R.output(gv1) return gv1 - @R.function + @R.function(private=True) def fused_relax_nn_relu( x11: R.Tensor((10,), dtype="float32") ) -> R.Tensor((10,), dtype="float32"): @@ -519,7 +519,7 @@ def fused_relax_nn_relu( R.output(gv2) return gv2 - @R.function + @R.function(private=True) def fused_relax_nn_gelu( x21: R.Tensor((10,), dtype="float32") ) -> R.Tensor((10,), dtype="float32"): @@ -529,7 +529,7 @@ def fused_relax_nn_gelu( R.output(gv3) return gv3 - @R.function + @R.function(private=True) def fused_relax_add( lv: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,), dtype="float32") ) -> R.Tensor((10,), dtype="float32"): @@ -627,7 +627,7 @@ def main(x1: R.Tensor((10,), dtype="float32")) -> 
R.Tensor((10,), dtype="float32 R.output(gv1) return gv1 - @R.function + @R.function(private=True) def fused_relax_nn_relu( x11: R.Tensor((10,), dtype="float32") ) -> R.Tensor((10,), dtype="float32"): @@ -637,7 +637,7 @@ def fused_relax_nn_relu( R.output(gv2) return gv2 - @R.function + @R.function(private=True) def fused_relax_nn_gelu( x21: R.Tensor((10,), dtype="float32") ) -> R.Tensor((10,), dtype="float32"): @@ -647,7 +647,7 @@ def fused_relax_nn_gelu( R.output(gv3) return gv3 - @R.function + @R.function(private=True) def fused_relax_add( lv: R.Tensor((10,), dtype="float32"), gelu1: R.Tensor((10,), dtype="float32") ) -> R.Tensor((10,), dtype="float32"): @@ -759,7 +759,7 @@ def main( R.output(gv1) return gv1 - @R.function + @R.function(private=True) def fused_relax_nn_relu( add2: R.Tensor((10,), dtype="float32") ) -> R.Tensor((10,), dtype="float32"): @@ -769,7 +769,7 @@ def fused_relax_nn_relu( R.output(gv) return gv - @R.function + @R.function(private=True) def fused_relax_add( x11: R.Tensor((10,), dtype="float32"), x21: R.Tensor((10,), dtype="float32") ) -> R.Tensor((10,), dtype="float32"): @@ -779,7 +779,7 @@ def fused_relax_add( R.output(gv2) return gv2 - @R.function + @R.function(private=True) def fused_relax_nn_gelu( x31: R.Tensor((10,), dtype="float32") ) -> R.Tensor((10,), dtype="float32"): @@ -924,7 +924,7 @@ def main( R.output(conv) return conv - @R.function + @R.function(private=True) def fused_relax_nn_conv2d( data1: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), @@ -988,6 +988,9 @@ def lv1( def check(mod, expected): partitioned = relax.transform.MergeCompositeFunctions()(mod) + print(partitioned.script()) + print() + print(expected.script()) tvm.ir.assert_structural_equal(partitioned, expected) @@ -1071,7 +1074,7 @@ def test_reshape(): # Verify that the non-CallNode input (shape in reshape) can be handled properly. 
@I.ir_module class Module: - @R.function + @R.function(private=True) def fused_relax_matmul( lv: R.Tensor((1, 784), dtype="float32"), lv1: R.Tensor((784, 512), dtype="float32") ) -> R.Tensor((1, 512), dtype="float32"): @@ -1081,7 +1084,7 @@ def fused_relax_matmul( R.output(gv) return gv - @R.function + @R.function(private=True) def fused_relax_reshape( inp_0: R.Tensor((1, 1, 28, 28), dtype="float32"), param_0: R.Shape([1, 784]) ) -> R.Tensor((1, 784), dtype="float32"): diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py index 931d206afbb1..66abb2f027ae 100644 --- a/tests/python/relax/test_transform_rewrite_cuda_graph.py +++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py @@ -82,7 +82,7 @@ def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T T.writes(compute[i0, i1]) compute[i0, i1] = T.exp(rxplaceholder[i0, i1], dtype="float32") - @R.function + @R.function(private=True) def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object, R.Object): R.func_attr({"relax.force_pure": True}) storage: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32")) @@ -91,7 +91,7 @@ def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object, R.Object): gv: R.Tuple(R.Object, R.Object, R.Object) = (storage, storage1, storage2) return gv - @R.function + @R.function(private=True) def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object, storage2: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")): R.func_attr({"relax.force_pure": True}) cls = Expected @@ -193,7 +193,7 @@ def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T T.writes(compute[i0, i1]) compute[i0, i1] = T.exp(rxplaceholder[i0, i1]) - @R.function + @R.function(private=True) def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object): R.func_attr({"relax.force_pure": True}) storage: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32")) @@ -201,7 +201,7 @@ def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object): gv: R.Tuple(R.Object, R.Object) = (storage, storage1) return gv - @R.function + @R.function(private=True) def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")): R.func_attr({"relax.force_pure": True}) cls = Expected diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 7a8bcdee26ec..0206459558ac 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -27,6 +27,8 @@ from tvm.script.parser import relax as R from tvm.script.parser import tir as T +from tvm.relax.testing import dump_ast + def _check( parsed: Union[relax.Function, IRModule], @@ -203,7 +205,7 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): x = relax.Var("x", R.Tensor((128, 128), "float32")) bb = relax.BlockBuilder() - with bb.function("foo", (x,)): + with bb.function("foo", (x,), {"global_symbol": "foo"}): out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func") bb.emit_func_output(out) @@ -232,7 +234,7 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): x = relax.Var("x", R.Tensor((128, 128), "float32")) bb = relax.BlockBuilder() - with bb.function("foo", (x,)): + with bb.function("foo", (x,), 
{"global_symbol": "foo"}): out = bb.emit_te( lambda x: x + 1, x, @@ -254,7 +256,7 @@ def main(x: R.Tensor((10, 20), "float32")) -> R.Tensor((10, 20), dtype="float32" bb = relax.BlockBuilder() x = relax.Var("x", relax.TensorStructInfo([10, 20], "float32")) - with bb.function("main", [x]): + with bb.function("main", [x], {"global_symbol": "main"}): lv1 = bb.emit_te(topi.add, x, x) out = bb.emit_te(topi.multiply, lv1, lv1) bb.emit_func_output(out) @@ -294,7 +296,7 @@ def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), "float32"): x = relax.Var("x", R.Tensor((128, 128), "float32")) bb = relax.BlockBuilder() - with bb.function("foo", (x,)): + with bb.function("foo", (x,), {"global_symbol": "foo"}): out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func") bb.emit_func_output(out) mod = bb.get() @@ -834,7 +836,7 @@ def foo(x: R.Tensor((), "float32")): def test_call_tir_empty_tuple_arg(): bb = relax.BlockBuilder() dummy_param = relax.Var("dummy_param", R.Tensor(())) - with bb.function("foo", [dummy_param]): + with bb.function("foo", [dummy_param], {"global_symbol": "foo"}): output = bb.emit_te(topi.full, shape=(16, 32), dtype="float32", fill_value=1.0) bb.emit_func_output(output) @@ -1493,5 +1495,22 @@ def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: _check(foo, bb.get()["foo"]) +def test_private_function(): + @I.ir_module + class Addition: + @R.function(private=True) + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + y = R.add(x, x) + return y + + x = relax.Var("x", R.Tensor((), "int32")) + bb = relax.BlockBuilder() + with bb.function("main", (x)): + y = bb.emit(R.add(x, x)) + bb.emit_func_output(y) + + _check(Addition, bb.get()) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index 7525c63be440..5b9e232cf3cd 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -47,8 +47,10 @@ def main(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): def test_extern_func(): + # note: this function will be treated as private unless a global symbol is added @R.function def relax_func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore + R.func_attr({"global_symbol": "func"}) return a obj = IRModule( @@ -576,5 +578,102 @@ def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): ) +def test_private_function(): + @I.ir_module + class AddMod: + @R.function(private=True) + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + y: R.Tensor((), dtype="int32") = R.add(x, x) + return y + + _assert_print( + AddMod, + """ +# from tvm.script import ir as I +# from tvm.script import relax as R + +@I.ir_module +class Module: + @R.function(private=True) + def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + y: R.Tensor((), dtype="int32") = R.add(x, x) + return y +""", + ) + + +def test_directly_construct_private_funcs(): + # public + @R.function + def func1(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + R.func_attr({"global_symbol": "foo"}) + y: R.Tensor((), dtype="int32") = R.add(x, x) + return y + + # private + @R.function + def func2(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + y: R.Tensor((), dtype="int32") = R.multiply(x, x) + return y + + # public but there's another attribute + @R.function + def func3(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + R.func_attr({"global_symbol": "baz", "relax.force_pure": True}) + y: 
R.Tuple = R.print(format="Hi there!") + z: R.Tensor((), dtype="int32") = R.add(x, x) + return z + + # private with an attribute + @R.function + def func4(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + R.func_attr({"relax.force_pure": True}) + y: R.Tuple = R.print(format="Lol") + z: R.Tensor((), dtype="int32") = R.multiply(x, x) + return z + + obj = IRModule( + { + "foo": func1, + "bar": func2, + "baz": func3, + "quux": func4, + } + ) + _assert_print( + obj, + """ +# from tvm.script import ir as I +# from tvm.script import relax as R + +@I.ir_module +class Module: + @R.function(private=True) + def bar(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + y: R.Tensor((), dtype="int32") = R.multiply(x, x) + return y + + @R.function + def baz(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + R.func_attr({"relax.force_pure": 1}) + y: R.Tuple = R.print(format=R.str("Hi there!")) + z: R.Tensor((), dtype="int32") = R.add(x, x) + return z + + @R.function + def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + y: R.Tensor((), dtype="int32") = R.add(x, x) + return y + + @R.function(private=True) + def quux(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + R.func_attr({"relax.force_pure": 1}) + y: R.Tuple = R.print(format=R.str("Lol")) + z: R.Tensor((), dtype="int32") = R.multiply(x, x) + return z +""", + ) + + if __name__ == "__main__": tvm.testing.main() From 4cd11a265e07f47d470cd25fcc5a36d05954652f Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 21 Jun 2023 18:19:04 -0400 Subject: [PATCH 02/26] Remove unused import --- tests/python/relax/test_tvmscript_parser.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 0206459558ac..ee67352c2994 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -27,9 +27,6 @@ from tvm.script.parser import relax as R from tvm.script.parser import tir as T -from tvm.relax.testing import dump_ast - - def _check( parsed: Union[relax.Function, IRModule], expect: Optional[Union[relax.Function, IRModule]] = None, From dabcb1e1a0b5d962baddbd595b9633ed9acd9141 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 21 Jun 2023 18:22:16 -0400 Subject: [PATCH 03/26] Simplify function frame construction --- src/script/ir_builder/relax/frame.cc | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index 67114bd7101b..9e7ce0a46285 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -82,13 +82,7 @@ void FunctionFrameNode::ExitWithScope() { if (decl->attrs.defined()) { auto attr_dict = decl->attrs.get()->dict; if (attr_dict.count("global_symbol") && !attrs.count("global_symbol")) { - Map new_attrs; - for (auto kv : attrs) { - new_attrs.Set(kv.first, kv.second); - } - new_attrs.Set("global_symbol", attr_dict.at("global_symbol")); - auto mut_f = func.CopyOnWrite(); - mut_f->attrs = DictAttrs(new_attrs); + func = std::move(WithAttr(func, tvm::attr::kGlobalSymbol, attr_dict.at("global_symbol"))); } } } From ebd1b80be29d75dec56203587f58a58c421518d9 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 21 Jun 2023 18:23:30 -0400 Subject: [PATCH 04/26] Remove debug prints --- tests/python/relax/test_transform_merge_composite_functions.py | 3 --- 1 file changed, 3 deletions(-) diff --git 
a/tests/python/relax/test_transform_merge_composite_functions.py b/tests/python/relax/test_transform_merge_composite_functions.py index f294b62bc7c8..d55226613137 100644 --- a/tests/python/relax/test_transform_merge_composite_functions.py +++ b/tests/python/relax/test_transform_merge_composite_functions.py @@ -988,9 +988,6 @@ def lv1( def check(mod, expected): partitioned = relax.transform.MergeCompositeFunctions()(mod) - print(partitioned.script()) - print() - print(expected.script()) tvm.ir.assert_structural_equal(partitioned, expected) From e2ecadc3ed8c729d81fce21eb8f54c0e9dd022e7 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 21 Jun 2023 22:41:11 -0400 Subject: [PATCH 05/26] Whitespace fix --- tests/python/relax/test_tvmscript_parser.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index ee67352c2994..4071d9429525 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -27,6 +27,7 @@ from tvm.script.parser import relax as R from tvm.script.parser import tir as T + def _check( parsed: Union[relax.Function, IRModule], expect: Optional[Union[relax.Function, IRModule]] = None, From f2c999d2ca9b516fcff0c8c2f86d22e1fd73a006 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 22 Jun 2023 13:18:35 -0400 Subject: [PATCH 06/26] Include global_symbol in importers --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 2 +- python/tvm/relax/frontend/stablehlo/stablehlo_translator.py | 3 ++- python/tvm/relax/frontend/torch/fx_translator.py | 5 ++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index d653bb551113..9edd7e914f7e 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1836,7 +1836,7 @@ def from_onnx(self, graph: onnx.onnx_ml_pb2.ModelProto, opset: int) -> IRModule: output_var = self.bb.emit_output(outputs) # Create function attributes for this module - func_attrs = {"num_input": self._num_input} + func_attrs = {"num_input": self._num_input, "global_symbol": "main"} # Create a function from our output expression and all input variables. input_list = [value for value in self._inputs.values() if isinstance(value, relax.Var)] # Attach params if they are available. diff --git a/python/tvm/relax/frontend/stablehlo/stablehlo_translator.py b/python/tvm/relax/frontend/stablehlo/stablehlo_translator.py index 1ca0856f63a4..f4ba7fcca3da 100644 --- a/python/tvm/relax/frontend/stablehlo/stablehlo_translator.py +++ b/python/tvm/relax/frontend/stablehlo/stablehlo_translator.py @@ -387,9 +387,10 @@ def from_stablehlo(self, model, input_info: List[Tuple[Tuple[int], str]]) -> tvm # Initialize the block builder with a function and a dataflow block. 
# Raise error if the input stablehlo op is impure func_name = "main" + func_attrs = {"global_symbol": func_name} self.block_builder = relax.BlockBuilder() - with self.block_builder.function(name=func_name, params=inputs.copy()): + with self.block_builder.function(name=func_name, params=inputs.copy(), func_attrs=func_attrs): output = None with self.block_builder.dataflow(): block = model.body.operations[0].regions[0].blocks[0] diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index ab6a707cb0ce..05aa08425a39 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -1245,18 +1245,17 @@ def from_fx( # Initialize the block builder with a function and a dataflow block. func_name = "main" + func_attrs = {"global_symbol": func_name} self.block_builder = relax.BlockBuilder() params = [] if keep_params_as_input: - func_attrs = {"num_input": len(inputs)} + func_attrs["num_input"] = len(inputs) for name, param in sorted(model.named_parameters(), key=lambda x: x[0]): shape = param.data.shape dtype = self._convert_data_type(str(param.data.dtype)) inputs.append(relax.Var(name, relax.TensorStructInfo(shape, dtype))) self.params[param] = inputs[-1] params.append(tvm.nd.array(param.data.cpu().numpy())) - else: - func_attrs = None with self.block_builder.function(name=func_name, params=inputs.copy(), attrs=func_attrs): output = None From e012683a9b34d72e156ec8217cc110e5f9d1d834 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 22 Jun 2023 15:15:04 -0400 Subject: [PATCH 07/26] Formatting fix --- python/tvm/relax/frontend/stablehlo/stablehlo_translator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/stablehlo/stablehlo_translator.py b/python/tvm/relax/frontend/stablehlo/stablehlo_translator.py index f4ba7fcca3da..4910d829db76 100644 --- a/python/tvm/relax/frontend/stablehlo/stablehlo_translator.py +++ b/python/tvm/relax/frontend/stablehlo/stablehlo_translator.py @@ -390,7 +390,9 @@ def from_stablehlo(self, model, input_info: List[Tuple[Tuple[int], str]]) -> tvm func_attrs = {"global_symbol": func_name} self.block_builder = relax.BlockBuilder() - with self.block_builder.function(name=func_name, params=inputs.copy(), func_attrs=func_attrs): + with self.block_builder.function( + name=func_name, params=inputs.copy(), func_attrs=func_attrs + ): output = None with self.block_builder.dataflow(): block = model.body.operations[0].regions[0].blocks[0] From 24769c05bc23a512e0a8e04eac3e021dd3a59c49 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Thu, 22 Jun 2023 16:50:19 -0400 Subject: [PATCH 08/26] Add private field to function builder in BlockBuilder --- python/tvm/relax/block_builder.py | 12 ++++++ .../tvm/relax/frontend/onnx/onnx_frontend.py | 2 +- .../stablehlo/stablehlo_translator.py | 5 +-- .../tvm/relax/frontend/torch/fx_translator.py | 5 ++- python/tvm/relax/training/setup_trainer.py | 5 ++- src/relax/training/utils.cc | 4 +- src/relax/transform/gradient.cc | 8 ++-- tests/python/relax/test_training_loss.py | 10 +++++ tests/python/relax/test_training_optimizer.py | 8 ++++ .../relax/test_training_optimizer_numeric.py | 2 +- tests/python/relax/test_transform_fuse_ops.py | 38 ++++++++++++------- tests/python/relax/test_transform_fuse_tir.py | 10 ++--- tests/python/relax/test_transform_gradient.py | 34 ++++++++--------- 13 files changed, 95 insertions(+), 48 deletions(-) diff --git a/python/tvm/relax/block_builder.py 
b/python/tvm/relax/block_builder.py index 80edb31efbf0..502073edf207 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -199,6 +199,7 @@ def function( name: str, params: Optional[Union[Var, Tuple, List[Var]]] = None, attrs: Optional[Dict[str, Object]] = None, + private: bool = False, ) -> FunctionScope: """Annotate a Relax function. @@ -215,6 +216,12 @@ def function( attrs : Dict[str, Object], optional The function attrs + private : bool, optional + Whether the function is annotated as private. + If the function is private, it will not have a global symbol attribute. + If it is not private and not an inner function, then it will have + a global symbol attribute (mapped to the function's name) + Returns ------- ret: FunctionScope @@ -233,6 +240,11 @@ def function( ) if attrs is None: attrs = {} + # The block builder does not permit nesting functions, per above comment, + # so no further check should be needed + if not private: + attrs["global_symbol"] = name + return FunctionScope(self, name, params, attrs) def testing_scope(self, def_vars: List[tir.Var]) -> TestingScope: diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 9edd7e914f7e..d653bb551113 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1836,7 +1836,7 @@ def from_onnx(self, graph: onnx.onnx_ml_pb2.ModelProto, opset: int) -> IRModule: output_var = self.bb.emit_output(outputs) # Create function attributes for this module - func_attrs = {"num_input": self._num_input, "global_symbol": "main"} + func_attrs = {"num_input": self._num_input} # Create a function from our output expression and all input variables. input_list = [value for value in self._inputs.values() if isinstance(value, relax.Var)] # Attach params if they are available. diff --git a/python/tvm/relax/frontend/stablehlo/stablehlo_translator.py b/python/tvm/relax/frontend/stablehlo/stablehlo_translator.py index 4910d829db76..1ca0856f63a4 100644 --- a/python/tvm/relax/frontend/stablehlo/stablehlo_translator.py +++ b/python/tvm/relax/frontend/stablehlo/stablehlo_translator.py @@ -387,12 +387,9 @@ def from_stablehlo(self, model, input_info: List[Tuple[Tuple[int], str]]) -> tvm # Initialize the block builder with a function and a dataflow block. # Raise error if the input stablehlo op is impure func_name = "main" - func_attrs = {"global_symbol": func_name} self.block_builder = relax.BlockBuilder() - with self.block_builder.function( - name=func_name, params=inputs.copy(), func_attrs=func_attrs - ): + with self.block_builder.function(name=func_name, params=inputs.copy()): output = None with self.block_builder.dataflow(): block = model.body.operations[0].regions[0].blocks[0] diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 05aa08425a39..ab6a707cb0ce 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -1245,17 +1245,18 @@ def from_fx( # Initialize the block builder with a function and a dataflow block. 
func_name = "main" - func_attrs = {"global_symbol": func_name} self.block_builder = relax.BlockBuilder() params = [] if keep_params_as_input: - func_attrs["num_input"] = len(inputs) + func_attrs = {"num_input": len(inputs)} for name, param in sorted(model.named_parameters(), key=lambda x: x[0]): shape = param.data.shape dtype = self._convert_data_type(str(param.data.dtype)) inputs.append(relax.Var(name, relax.TensorStructInfo(shape, dtype))) self.params[param] = inputs[-1] params.append(tvm.nd.array(param.data.cpu().numpy())) + else: + func_attrs = None with self.block_builder.function(name=func_name, params=inputs.copy(), attrs=func_attrs): output = None diff --git a/python/tvm/relax/training/setup_trainer.py b/python/tvm/relax/training/setup_trainer.py index 81ecaf4ea5d7..2e2057086904 100644 --- a/python/tvm/relax/training/setup_trainer.py +++ b/python/tvm/relax/training/setup_trainer.py @@ -198,7 +198,10 @@ def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRM # Add optimizer function. self._optimizer.init(params) - mod[self.OPTIMIZER_FUNC] = self._optimizer.get_function() + # Need the global symbol to match the function's name + mod[self.OPTIMIZER_FUNC] = self._optimizer.get_function().with_attr( + "global_symbol", self.OPTIMIZER_FUNC + ) # Module attrs mod = mod.with_attrs( diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc index 37582e301550..19faaad58b87 100644 --- a/src/relax/training/utils.cc +++ b/src/relax/training/utils.cc @@ -48,7 +48,9 @@ class AppendLossMutator : private ExprMutator { Function new_loss_func = CopyWithNewVars(loss_function); AppendLossMutator mutator(mod, new_loss_func, num_backbone_outputs); - auto new_func_transformed = Downcast(mutator.VisitExpr(new_func)); + auto new_func_transformed = + WithAttr(Downcast(mutator.VisitExpr(new_func)), tvm::attr::kGlobalSymbol, + new_func_name.value_or(func_name + "_loss")); auto new_module = GetRef(mod.CopyOnWrite()); auto new_var = GlobalVar(new_func_name.value_or(func_name + "_loss")); diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc index 65fb3b96d6c3..01540af598e2 100644 --- a/src/relax/transform/gradient.cc +++ b/src/relax/transform/gradient.cc @@ -333,12 +333,14 @@ class GradientMutator : private ExprMutator { } GradientMutator mutator(mod, require_grads_value, target_index); - // remove the global symbol if the original had one (the adjoint does not need a global symbol) + + // make the adjoint public + auto new_name = func_name + "_adjoint"; Function new_func_transformed = - WithoutAttr(Downcast(mutator.VisitExpr(new_func)), tvm::attr::kGlobalSymbol); + WithAttr(Downcast(mutator.VisitExpr(new_func)), tvm::attr::kGlobalSymbol, new_name); IRModule new_module = GetRef(mod.CopyOnWrite()); - new_module->Add(GlobalVar(func_name + "_adjoint"), new_func_transformed); + new_module->Add(GlobalVar(new_name), new_func_transformed); return new_module; } diff --git a/tests/python/relax/test_training_loss.py b/tests/python/relax/test_training_loss.py index 0a2418aad756..0d456ceb3873 100644 --- a/tests/python/relax/test_training_loss.py +++ b/tests/python/relax/test_training_loss.py @@ -46,6 +46,7 @@ def test_l1_loss(): def expected( predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3, 5), "float32") ) -> R.Tensor((), "float32"): + R.func_attr({"global_symbol": "l1_loss"}) with R.dataflow(): lv: R.Tensor((3, 5), "float32") = R.subtract(predictions, targets) lv1: R.Tensor((3, 5), "float32") = R.abs(lv) @@ -70,6 +71,7 @@ def expected( b: 
R.Tensor((2, 4), "float32"), targets: R.Tensor((2, 4), "float32"), ) -> R.Tensor((), "float32"): + R.func_attr({"global_symbol": "forward_loss"}) with R.dataflow(): lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="") out: R.Tensor((2, 4), "float32") = R.add(lv, b) @@ -93,6 +95,7 @@ def test_mse_loss(): def expected( predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3, 5), "float32") ) -> R.Tensor((), "float32"): + R.func_attr({"global_symbol": "mse_loss"}) with R.dataflow(): lv: R.Tensor((3, 5), "float32") = R.subtract(predictions, targets) lv1: R.Tensor((3, 5), "float32") = R.multiply(lv, lv) @@ -117,6 +120,7 @@ def expected( b: R.Tensor((2, 4), "float32"), targets: R.Tensor((2, 4), "float32"), ) -> R.Tensor((), "float32"): + R.func_attr({"global_symbol": "forward_loss"}) with R.dataflow(): lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="") out: R.Tensor((2, 4), "float32") = R.add(lv, b) @@ -143,6 +147,7 @@ def expected( targets: R.Tensor((3,), "int64"), weights: R.Tensor((5,), "float32"), ) -> R.Tensor((), "float32"): + R.func_attr({"global_symbol": "cross_entropy_loss"}) with R.dataflow(): lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1) gv: R.Tensor((), "float32") = R.nn.nll_loss( @@ -165,6 +170,7 @@ def test_cross_entropy_loss_without_weights(): def expected( predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3,), "int64") ) -> R.Tensor((), "float32"): + R.func_attr({"global_symbol": "cross_entropy_loss"}) with R.dataflow(): lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1) gv: R.Tensor((), "float32") = R.nn.nll_loss( @@ -195,6 +201,7 @@ def expected( targets: R.Tensor((2,), "int64"), weights: R.Tensor((4,), "float32"), ) -> R.Tensor((), "float32"): + R.func_attr({"global_symbol": "forward_loss"}) with R.dataflow(): lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="") out: R.Tensor((2, 4), "float32") = R.add(lv, b) @@ -224,6 +231,7 @@ def expected( targets: R.Tensor((3, 5), "int64"), weights: R.Tensor((5,), "float32"), ) -> R.Tensor((), "float32"): + R.func_attr({"global_symbol": "categorical_cross_entropy_loss"}) with R.dataflow(): lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1) lv: R.Tensor((), "float32") = -lv * targets.astype("float32") @@ -245,6 +253,7 @@ def test_categorical_cross_entropy_loss_without_weights(): def expected( predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3, 5), "int64") ) -> R.Tensor((), "float32"): + R.func_attr({"global_symbol": "categorical_cross_entropy_loss"}) with R.dataflow(): lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1) gv: R.Tensor((), "float32") = R.mean(-lv * targets.astype("float32")) @@ -270,6 +279,7 @@ def expected( targets: R.Tensor((3, 5), "int64"), weights: R.Tensor((5,), "float32"), ) -> R.Tensor((), "float32"): + R.func_attr({"global_symbol": "categorical_cross_entropy_loss"}) with R.dataflow(): lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, axis=-1) targets = relax.op.reshape( diff --git a/tests/python/relax/test_training_optimizer.py b/tests/python/relax/test_training_optimizer.py index b2246087c6d3..514422da8d01 100644 --- a/tests/python/relax/test_training_optimizer.py +++ b/tests/python/relax/test_training_optimizer.py @@ -67,6 +67,7 @@ def sgd_expected( R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")), R.Tuple(R.Tensor((), "int64")), ): + R.func_attr({"global_symbol": "SGD"}) # block 0 with R.dataflow(): num_steps: R.Tensor((), "int64") = 
optim_states[0] @@ -104,6 +105,7 @@ def sgd_expected( R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")), R.Tuple(R.Tensor((), "int64")), ): + R.func_attr({"global_symbol": "SGD"}) with R.dataflow(): num_steps: R.Tensor((), "int64") = optim_states[0] num_steps_new: R.Tensor((), "int64") = R.add(num_steps, R.const(1, "int64")) @@ -146,6 +148,7 @@ def msgd_expected( R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")), R.Tuple(R.Tensor((), "int64"), R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")), ): + R.func_attr({"global_symbol": "MomentumSGD"}) # block 0 with R.dataflow(): num_steps: R.Tensor((), "int64") = optim_states[0] @@ -195,6 +198,7 @@ def msgd_expected( R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")), R.Tuple(R.Tensor((), "int64"), R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")), ): + R.func_attr({"global_symbol": "MomentumSGD"}) # block 0 with R.dataflow(): num_steps: R.Tensor((), "int64") = optim_states[0] @@ -250,6 +254,7 @@ def msgd_expected( R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")), R.Tuple(R.Tensor((), "int64"), R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")), ): + R.func_attr({"global_symbol": "MomentumSGD"}) # block 0 with R.dataflow(): num_steps: R.Tensor((), "int64") = optim_states[0] @@ -321,6 +326,7 @@ def adam_expected( R.Tensor((3,), "float32"), ), ): + R.func_attr({"global_symbol": "Adam"}) # block 0 with R.dataflow(): num_steps: R.Tensor((), "int64") = optim_states[0] @@ -418,6 +424,7 @@ def adam_expected( R.Tensor((3,), "float32"), ), ): + R.func_attr({"global_symbol": "Adam"}) # block 0 with R.dataflow(): num_steps: R.Tensor((), "int64") = optim_states[0] @@ -519,6 +526,7 @@ def adam_expected( R.Tensor((3,), "float64"), ), ): + R.func_attr({"global_symbol": "Adam"}) # block 0 with R.dataflow(): num_steps: R.Tensor((), "int64") = optim_states[0] diff --git a/tests/python/relax/test_training_optimizer_numeric.py b/tests/python/relax/test_training_optimizer_numeric.py index 3b300e826120..23db8987f12d 100644 --- a/tests/python/relax/test_training_optimizer_numeric.py +++ b/tests/python/relax/test_training_optimizer_numeric.py @@ -69,7 +69,7 @@ def _test_optimizer(target, dev, np_func, opt_type, *args, **kwargs): x = relax.Var("x", R.Tensor((3, 3), "float32")) y = relax.Var("y", R.Tensor((3,), "float32")) opt = opt_type(*args, **kwargs).init([x, y]) - mod = IRModule.from_expr(opt.get_function()) + mod = IRModule.from_expr(opt.get_function().with_attr("global_symbol", "main")) tvm_func = _legalize_and_build(mod, target, dev)["main"] param_arr = [np.random.rand(3, 3).astype(np.float32), np.random.rand(3).astype(np.float32)] diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py index 6a30aabf3570..b04677b6f51e 100644 --- a/tests/python/relax/test_transform_fuse_ops.py +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -48,7 +48,7 @@ def expected(): x = relax.Var("x", R.Tensor([10, 20], "float32")) p0 = relax.Var("p0", R.Tensor((), "float32")) - with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}): + with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}, private=True): with bb.dataflow(): lv0 = bb.emit_te(topi.add, x, p0) lv1 = bb.emit_te(topi.exp, lv0) @@ -100,7 +100,9 @@ def expected(dtype): x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) w = relax.Var("w", R.Tensor((16, 16, 3, 3), dtype)) p0 = relax.Var("p0", R.Tensor((), dtype)) - with bb.function("fused_conv2d_add1_add2", [x, w, 
p0], attrs={"Primitive": 1}): + with bb.function( + "fused_conv2d_add1_add2", [x, w, p0], attrs={"Primitive": 1}, private=True + ): with bb.dataflow(): lv0 = bb.emit_te( topi.nn.conv2d, @@ -119,7 +121,7 @@ def expected(dtype): x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) w = relax.Var("w", R.Tensor((16, 16, 1, 1), dtype)) y = relax.Var("y", R.Tensor((1, 16, 64, 64), dtype)) - with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive": 1}): + with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive": 1}, private=True): with bb.dataflow(): lv0 = bb.emit_te( topi.nn.conv2d, @@ -196,7 +198,9 @@ def expected(): x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) w = relax.Var("w", R.Tensor((1, 16, 32, 32), "float32")) p0 = relax.Var("p0", R.Tensor((), "float32")) - with bb.function("fused_upsampling_concatenate_add", [w, x, p0], attrs={"Primitive": 1}): + with bb.function( + "fused_upsampling_concatenate_add", [w, x, p0], attrs={"Primitive": 1}, private=True + ): with bb.dataflow(): lv0 = bb.emit_te(topi.nn.upsampling, w, scale_h=2.0, scale_w=2.0) lv1 = bb.emit_te(topi.concatenate, (lv0, x), axis=1) @@ -287,7 +291,10 @@ def expected(dim: int): # Grouped function dense = relax.Var("dense", R.Tensor((1, 3 * dim), "float32")) with bb.function( - "fused_split_sigmoid_tanh_exp_multiply_add", [dense], attrs={"Primitive": 1} + "fused_split_sigmoid_tanh_exp_multiply_add", + [dense], + attrs={"Primitive": 1}, + private=True, ): with bb.dataflow(): lv0 = bb.emit_te(topi.split, dense, indices_or_sections=3, axis=1) @@ -340,7 +347,7 @@ def expected(dim: int): # Grouped function x = relax.Var("x", R.Tensor((1, 3 * dim), "float32")) - with bb.function("fused_split", [x], attrs={"Primitive": 1}): + with bb.function("fused_split", [x], attrs={"Primitive": 1}, private=True): with bb.dataflow(): lv0 = bb.emit_te(topi.split, x, indices_or_sections=3, axis=1) gv = bb.emit_output(relax.TupleGetItem(lv0, 0)) @@ -398,6 +405,7 @@ def expected(): "fused_squeeze_add_squeeze1_add_add_add_concatenate_squeeze2_add1", [x, p0, p1, p2, p3, p4], attrs={"Primitive": 1}, + private=True, ): with bb.dataflow(): lv0 = bb.emit_te(topi.squeeze, x) @@ -500,6 +508,7 @@ def expected(): "fused_add_add_add_concatenate_add1_add_add_add_concatenate_add1_add_add_add_concatenate_add1_concatenate1", [x, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11], attrs={"Primitive": 1}, + private=True, ): with bb.dataflow(): lv0 = bb.emit_te(topi.add, x, p0) @@ -523,7 +532,7 @@ def expected(): # Grouped function 2 concat = relax.Var("concat", R.Tensor((1, 144, 64, 64), "float32")) p0 = relax.Var("p0", R.Tensor((), "float32")) - with bb.function("fused_pool2d_add2", [concat, p0], attrs={"Primitive": 1}): + with bb.function("fused_pool2d_add2", [concat, p0], attrs={"Primitive": 1}, private=True): with bb.dataflow(): lv0 = bb.emit_te( topi.nn.pool2d, @@ -609,7 +618,7 @@ def expected(): # Grouped function 1 x = relax.Var("x", R.Tensor((1, 16, 64, 64), "float32")) w = relax.Var("w", R.Tensor((16, 16, 3, 3), "float32")) - with bb.function("fused_conv2d_relu", [x, w], attrs={"Primitive": 1}): + with bb.function("fused_conv2d_relu", [x, w], attrs={"Primitive": 1}, private=True): with bb.dataflow(): lv0 = bb.emit_te( topi.nn.conv2d, @@ -626,7 +635,7 @@ def expected(): # Grouped function 2 x = relax.Var("x", R.Tensor((1, 32, 64, 64), "float32")) w = relax.Var("w", R.Tensor((16, 32, 3, 3), "float32")) - with bb.function("fused_conv2d1_relu", [x, w], attrs={"Primitive": 1}): + with bb.function("fused_conv2d1_relu", [x, w], 
attrs={"Primitive": 1}, private=True): with bb.dataflow(): lv0 = bb.emit_te( topi.nn.conv2d, @@ -689,7 +698,10 @@ def expected(): x = relax.Var("x", R.Tensor((10, 20), "int32")) p0 = relax.Var("p0", R.Tensor((), "int32")) with bb.function( - "fused_add_squeeze_transpose_transpose1_left_shift", [x, p0], attrs={"Primitive": 1} + "fused_add_squeeze_transpose_transpose1_left_shift", + [x, p0], + attrs={"Primitive": 1}, + private=True, ): with bb.dataflow(): lv0 = bb.emit_te(topi.add, x, p0) @@ -734,7 +746,7 @@ def expected(): # Grouped function x = relax.Var("x", R.Tensor((16, 16), "float32")) - with bb.function("fused_softmax_cast", [x], attrs={"Primitive": 1}): + with bb.function("fused_softmax_cast", [x], attrs={"Primitive": 1}, private=True): with bb.dataflow(): lv0 = bb.emit_te(topi.nn.softmax, x) gv = bb.emit_output(bb.call_te(topi.cast, lv0, dtype="float16")) @@ -781,7 +793,7 @@ def expected(): x = relax.Var("x", R.Tensor([10, 20], "float32")) p0 = relax.Var("p0", R.Tensor((), "float32")) - with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}): + with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}, private=True): with bb.dataflow(): lv0 = bb.emit_te(topi.add, x, p0) lv1 = bb.emit_te(topi.exp, lv0) @@ -791,7 +803,7 @@ def expected(): x = relax.Var("x", R.Tensor([20, 10], "float32")) p0 = relax.Var("p0", R.Tensor((), "float32")) - with bb.function("fused_add1_exp1_squeeze1", [x, p0], attrs={"Primitive": 1}): + with bb.function("fused_add1_exp1_squeeze1", [x, p0], attrs={"Primitive": 1}, private=True): with bb.dataflow(): lv0 = bb.emit_te(topi.add, x, p0) lv1 = bb.emit_te(topi.exp, lv0) diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index 1aca1a1b583f..f59e3f2e9e6f 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -32,7 +32,7 @@ def before(): x = relax.Var("x", R.Tensor([10, 20], "float32")) p0 = relax.Var("p0", R.Tensor([], "float32")) - with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": True}): + with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": True}, private=True): with bb.dataflow(): lv0 = bb.emit_te(topi.add, x, p0) lv1 = bb.emit_te(topi.exp, lv0) @@ -565,7 +565,7 @@ def before(): x = relax.Var("x", R.Tensor([10, 20], "float32")) p0 = relax.Var("p0", R.Tensor((), "float32")) - with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}): + with bb.function("fused_add_exp_squeeze", [x, p0], attrs={"Primitive": 1}, private=True): with bb.dataflow(): lv0 = bb.emit_te(topi.add, x, p0) lv1 = bb.emit_te(topi.exp, lv0) @@ -575,7 +575,7 @@ def before(): x = relax.Var("x", R.Tensor([20, 10], "float32")) p0 = relax.Var("p0", R.Tensor((), "float32")) - with bb.function("fused_add1_exp1_squeeze1", [x, p0], attrs={"Primitive": 1}): + with bb.function("fused_add1_exp1_squeeze1", [x, p0], attrs={"Primitive": 1}, private=True): with bb.dataflow(): lv0 = bb.emit_te(topi.add, x, p0) lv1 = bb.emit_te(topi.exp, lv0) @@ -603,7 +603,7 @@ def before(): @I.ir_module class Expected: - @R.function(private=True) + @R.function def func1(x: R.Tensor((10, 20), dtype="float32")) -> R.Tensor((10, 20), dtype="float32"): with R.dataflow(): gv2 = R.call_tir( @@ -614,7 +614,7 @@ def func1(x: R.Tensor((10, 20), dtype="float32")) -> R.Tensor((10, 20), dtype="f R.output(gv2) return gv2 - @R.function(private=True) + @R.function def func2(x: R.Tensor((20, 10), dtype="float32")) -> R.Tensor((20, 10), 
dtype="float32"): with R.dataflow(): gv3 = R.call_tir( diff --git a/tests/python/relax/test_transform_gradient.py b/tests/python/relax/test_transform_gradient.py index d84d76291959..50063fe385bb 100644 --- a/tests/python/relax/test_transform_gradient.py +++ b/tests/python/relax/test_transform_gradient.py @@ -44,7 +44,7 @@ def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor(None, "float32", ndim=0): R.output(gv) return gv - @R.function(private=True) + @R.function def main_adjoint(x: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor(None, "float32", ndim=0),R.Tuple(R.Tensor(None, "float32", ndim=2)),): with R.dataflow(): gv: R.Tensor((), "float32") = R.sum(x, axis=None, keepdims=False) @@ -83,7 +83,7 @@ def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"): R.output(gv) return gv - @R.function(private=True) + @R.function def main_adjoint(x: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))): with R.dataflow(): lv1: R.Tensor((3, 3), "float32") = x @@ -125,7 +125,7 @@ def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"): R.output(gv) return gv - @R.function(private=True) + @R.function def main_adjoint(x: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))): with R.dataflow(): lv1: R.Tensor((3, 3), "float32") = R.add(x, x) @@ -168,7 +168,7 @@ def main(x: R.Tensor((3, 3), "float32")) -> R.Tensor((), "float32"): R.output(gv) return gv - @R.function(private=True) + @R.function def main_adjoint(x: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))): with R.dataflow(): lv1: R.Tensor((3, 3), "float32") = R.add(x, x) @@ -217,7 +217,7 @@ def main( R.output(gv) return gv - @R.function(private=True) + @R.function def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32"), z: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))): with R.dataflow(): lv1: R.Tensor((3, 3), "float32") = R.add(x, y) @@ -249,7 +249,7 @@ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32"), z: R.Te R.output(gv) return gv - @R.function(private=True) + @R.function def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32"), z: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"))): # block 0 with R.dataflow(): @@ -292,7 +292,7 @@ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Tu R.output(lv1, lv2, lv3) return (lv1, lv2, lv3) - @R.function(private=True) + @R.function def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((), "float32"), R.Tensor((), "float32")), R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))): with R.dataflow(): lv1: R.Tensor((3, 3), "float32") = x @@ -341,7 +341,7 @@ def main(x: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")), y R.output(gv) return gv - @R.function(private=True) + @R.function def main_adjoint(x: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")), y: R.Tensor((3, 3), "float32"), z: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))): with R.dataflow(): lv1: R.Tuple(R.Tensor((3, 3), "float32"), 
R.Tensor((3, 3), "float32")) = (y, z) @@ -399,7 +399,7 @@ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Te R.output(gv) return gv - @R.function(private=True) + @R.function def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))): # block 0 with R.dataflow(): @@ -473,7 +473,7 @@ def main(x: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float R.output(gv) return gv - @R.function(private=True) + @R.function def main_adjoint(x: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32")), y: R.Tensor((3, 3), "float32"), z: R.Tensor((3, 3), "float32"), u: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))): with R.dataflow(): lv1: R.Tuple(R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")), R.Tensor((3, 3), "float32")) = ((y, z), u) @@ -552,7 +552,7 @@ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Te R.output(gv) return gv - @R.function(private=True) + @R.function def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))): with R.dataflow(): lv0: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")) = (x, y) @@ -621,7 +621,7 @@ def main(x: R.Tensor((6,), "float32")) -> R.Tensor((), "float32"): R.output(gv) return gv - @R.function(private=True) + @R.function def main_adjoint(x: R.Tensor((6,), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((6,), "float32"))): with R.dataflow(): lv1: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = R.split(x, indices_or_sections=2, axis=0) @@ -671,7 +671,7 @@ def main(x: R.Tensor((3,), "float32"), y: R.Tuple(R.Tensor((3,), "float32"), R.T R.output(gv) return gv - @R.function(private=True) + @R.function def main_adjoint(x: R.Tensor((3,), "float32"), y: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32"))) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3,), "float32"), R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")))): with R.dataflow(): lv1: R.Tuple(R.Tensor((3,), "float32"), R.Tensor((3,), "float32")) = (x, x) @@ -740,7 +740,7 @@ def main(x: R.Tensor((3,), "float32")) -> R.Tensor((), "float32"): R.output(gv) return gv - @R.function(private=True) + @R.function def main_adjoint(x: R.Tensor((3,), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3,), "float32"))): # block 0 with R.dataflow(): @@ -806,7 +806,7 @@ def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Te R.output(gv) return gv - @R.function(private=True) + @R.function def main_adjoint(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32"))): with R.dataflow(): lv1: R.Tensor((3, 3), "float32") = R.add(x, cst) @@ -1070,7 +1070,7 @@ def main(x: R.Tensor((3, 4), "float32")): @I.ir_module class Expected: - @R.function(private=True) + @R.function def main_adjoint(x: R.Tensor((3, 4), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((3, 4), dtype="float32"))): with 
R.dataflow(): s: R.Shape([3, 2, 2]) = R.shape([3, 2, 2]) @@ -1122,7 +1122,7 @@ def main( @I.ir_module class Expected: - @R.function(private=True) + @R.function def main_adjoint(x: R.Tensor((3, 10), dtype="float32"), w0: R.Tensor((10, 5), dtype="float32"), b0: R.Tensor((5,), dtype="float32"), label: R.Tensor((3, 5), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tuple(R.Tensor((10, 5), dtype="float32"), R.Tensor((5,), dtype="float32"))): with R.dataflow(): lv0: R.Tensor((3, 5), dtype="float32") = R.matmul(x, w0, out_dtype="void") From aaee0cd267d459987b1ae911bdd47daf44ec4388 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 23 Jun 2023 11:49:16 -0400 Subject: [PATCH 09/26] Whitespace fix --- src/relax/transform/gradient.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/relax/transform/gradient.cc b/src/relax/transform/gradient.cc index 01540af598e2..2cda7a972d3c 100644 --- a/src/relax/transform/gradient.cc +++ b/src/relax/transform/gradient.cc @@ -333,11 +333,11 @@ class GradientMutator : private ExprMutator { } GradientMutator mutator(mod, require_grads_value, target_index); - + // make the adjoint public auto new_name = func_name + "_adjoint"; - Function new_func_transformed = - WithAttr(Downcast(mutator.VisitExpr(new_func)), tvm::attr::kGlobalSymbol, new_name); + Function new_func_transformed = WithAttr(Downcast(mutator.VisitExpr(new_func)), + tvm::attr::kGlobalSymbol, new_name); IRModule new_module = GetRef(mod.CopyOnWrite()); new_module->Add(GlobalVar(new_name), new_func_transformed); From 801000b606fec0e83d0284f6808efb443ed9c7db Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 23 Jun 2023 13:16:13 -0400 Subject: [PATCH 10/26] Set the global symbol for functions outside a module too --- include/tvm/script/ir_builder/relax/frame.h | 2 ++ include/tvm/script/ir_builder/relax/ir.h | 3 ++- python/tvm/script/ir_builder/relax/ir.py | 10 ++++++++-- python/tvm/script/parser/core/parser.py | 2 ++ python/tvm/script/parser/relax/parser.py | 19 +++++++++--------- src/script/ir_builder/relax/frame.cc | 17 +++++----------- src/script/ir_builder/relax/ir.cc | 3 ++- .../test_transform_combine_parallel_matmul.py | 20 +++++++++---------- .../python/relax/test_transform_normalize.py | 20 +++++++++---------- tests/python/relax/test_tvmscript_parser.py | 2 +- .../relax/test_tvmscript_printer_relax.py | 12 +++++------ 11 files changed, 56 insertions(+), 54 deletions(-) diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h index 9a8f835e819b..1ad681388912 100644 --- a/include/tvm/script/ir_builder/relax/frame.h +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -99,6 +99,8 @@ class FunctionFrameNode : public SeqExprFrameNode { Optional ret_struct_info; /*! \brief Whether the function is annotated as pure */ Optional is_pure; + /*! \brief Whether the function is annotated as private */ + Optional is_private; /*! \brief The function attributes. */ Map attrs; /*! \brief The block builder to create Relax function. */ diff --git a/include/tvm/script/ir_builder/relax/ir.h b/include/tvm/script/ir_builder/relax/ir.h index 1cf30b491957..d160ad090e48 100644 --- a/include/tvm/script/ir_builder/relax/ir.h +++ b/include/tvm/script/ir_builder/relax/ir.h @@ -34,9 +34,10 @@ namespace relax { /*! * \brief Start a function frame. * \param is_pure Whether the function is annotated as pure. + * \param is_private Whether the function is annotated as private. * \return The created ir_builder Function frame. 
*/ -TVM_DLL FunctionFrame Function(const Bool& is_pure); +TVM_DLL FunctionFrame Function(const Bool& is_pure, const Bool& is_private); /*! * \brief Add a parameter to the last function frame. diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index e54e4aa07b3b..bb3fda73c1a6 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -165,19 +165,25 @@ ############################### Function ################################ -def function(is_pure: bool = True) -> frame.FunctionFrame: +def function(is_pure: bool = True, is_private: bool = False) -> frame.FunctionFrame: """Start a function frame. Parameters ---------- is_pure: bool Whether the function is annotated as pure. + is_private : bool + Whether the function is annotated as private. + Returns ------- frame: FunctionFrame The constructed function frame. """ - return _ffi_api.Function(is_pure) # type: ignore[attr-defined] # pylint: disable=no-member + return _ffi_api.Function( + is_pure, + is_private, + ) # type: ignore[attr-defined] # pylint: disable=no-member def arg(name: py_str, struct_info: StructInfo) -> Var: diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 9275924466d5..0e9068f74cd5 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -240,6 +240,7 @@ class Parser(doc.NodeVisitor): dispatch_tokens: List[str] function_annotations: Optional[Dict[str, Dict[str, Any]]] var_table: VarTable + inside_function: bool # whether we are within a function def __init__( self, @@ -250,6 +251,7 @@ def __init__( self.dispatch_tokens = ["default"] self.function_annotations = function_annotations self.var_table = VarTable() + self.inside_function = False def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any: """The main parse method for parser. 
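# A hedged sketch (not part of the diffs) of the parsing behavior targeted by the
# relax parser change below: an explicit `private=True` decorator argument, or being
# nested inside another Relax function, yields a function without a global_symbol.
from tvm.script import ir as I
from tvm.script import relax as R

@I.ir_module
class Example:
    @R.function
    def public_fn(x: R.Tensor((2, 2), "float32")) -> R.Tensor((2, 2), "float32"):
        gv = R.add(x, x)  # expected attrs: {"global_symbol": "public_fn"}
        return gv

    @R.function(private=True)
    def private_fn(x: R.Tensor((2, 2), "float32")) -> R.Tensor((2, 2), "float32"):
        gv = R.multiply(x, x)  # expected to have no global_symbol attribute
        return gv

@R.function
def outer(x: R.Tensor((2, 2), "float32")) -> R.Tensor((2, 2), "float32"):
    # a nested local function is treated as private regardless of its annotation
    @R.function
    def inner(y: R.Tensor((2, 2), "float32")) -> R.Tensor((2, 2), "float32"):
        return y
    z = inner(x)
    return z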
diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index d69a841d4256..b8b7e28a1104 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -21,7 +21,7 @@ from typing import Any, Dict, Optional from tvm import relax, tir -from tvm.ir import GlobalVar, make_node, structural_equal +from tvm.ir import GlobalVar, structural_equal from tvm.relax import Expr, StructInfo from tvm.relax.utils import convert_to_expr from tvm.script.ir_builder.relax.frame import BlockFrame @@ -160,6 +160,9 @@ def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) -> Non @dispatch.register(token="relax", type_name="FunctionDef") def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: + is_inner_function = self.inside_function + self.inside_function = True + # reserve a var for local function func_val = self.var_table.get().get(node.name) if not func_val and is_recursive(node): @@ -179,11 +182,12 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: self.var_table.add(node.name, local_func_var) purity = find_decorator_annotation(node, "pure") - # don't handle the privacy annotation here because it's only relevant for global funcs + # treat the function as private if we are inside another function or if it has a privacy annotation + privacy = is_inner_function or find_decorator_annotation(node, "private", default=False) with self.var_table.with_frame(): with self.with_dispatch_token("relax"): - with R.function(is_pure=purity): + with R.function(is_pure=purity, is_private=privacy): R.func_name(node.name) collect_symbolic_var_from_params(self, node) @@ -203,6 +207,7 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: self.report_error(stmt, "inline prim_func is disallowed in Relax IR") self.visit_body(node.body) + self.inside_function = is_inner_function def find_decorator_annotation(node: doc.FunctionDef, annotation: str, default: bool = True) -> bool: @@ -240,13 +245,7 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar is_pure = find_decorator_annotation(node, "pure") - # if the global function is not private, then use its name as the global symbol - is_private = find_decorator_annotation(node, "private", default=False) - attrs = None - if not is_private: - attrs = make_node("DictAttrs", global_symbol=node.name) - - func_signature = relax.Function.create_empty(params, ret_sinfo, is_pure=is_pure, attrs=attrs) + func_signature = relax.Function.create_empty(params, ret_sinfo, is_pure=is_pure) return I.decl_function(node.name, func_signature) diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index 9e7ce0a46285..966af809c9b4 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -56,6 +56,11 @@ void FunctionFrameNode::ExitWithScope() { "`return` to return an Expr"; this->block_builder->BeginScope(params); Expr body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value())); + // if the function is not private, add a global symbol to its attributes + if (!is_private.value_or(Bool(false))->value && name.defined() && + !attrs.count(tvm::attr::kGlobalSymbol)) { + attrs.Set(tvm::attr::kGlobalSymbol, name.value()); + } auto dict_attrs = attrs.empty() ? 
NullValue() : DictAttrs(attrs); this->block_builder->EndScope(); tvm::relax::Function func(/*params=*/params, @@ -74,18 +79,6 @@ void FunctionFrameNode::ExitWithScope() { "function scope, if it's defined in a Module"; const IRModuleFrame& frame = opt_frame.value(); const String& func_name = name.value_or(""); - // If the function has already been declared (i.e., it is global), see if there is - // already a global symbol defined for it (i.e., it is not private). - // If yes, add it to the current function's attributes (unless one was manually defined) - if (frame->global_var_map.count(func_name)) { - auto decl = frame->functions.at(frame->global_var_map.at(func_name)); - if (decl->attrs.defined()) { - auto attr_dict = decl->attrs.get()->dict; - if (attr_dict.count("global_symbol") && !attrs.count("global_symbol")) { - func = std::move(WithAttr(func, tvm::attr::kGlobalSymbol, attr_dict.at("global_symbol"))); - } - } - } if (!frame->global_var_map.count(func_name)) { // First time visiting the function. ir::DeclFunction(func_name, func); diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index 52d9f0cfe10e..d66e8d059813 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -52,7 +52,7 @@ TVM_STATIC_IR_FUNCTOR(Namer, vtable) /////////////////////////////// Function //////////////////////////////// -FunctionFrame Function(const Bool& is_pure) { +FunctionFrame Function(const Bool& is_pure, const Bool& is_private) { ObjectPtr n = make_object(); const IRBuilder& ir_builder = IRBuilder::Current(); Optional mod = NullOpt; @@ -61,6 +61,7 @@ FunctionFrame Function(const Bool& is_pure) { } n->block_builder = tvm::relax::BlockBuilder::Create(/*mod=*/mod); n->is_pure = is_pure; + n->is_private = is_private; return FunctionFrame(n); } diff --git a/tests/python/relax/test_transform_combine_parallel_matmul.py b/tests/python/relax/test_transform_combine_parallel_matmul.py index 66861f4ed485..97211f0dd0ff 100644 --- a/tests/python/relax/test_transform_combine_parallel_matmul.py +++ b/tests/python/relax/test_transform_combine_parallel_matmul.py @@ -97,7 +97,7 @@ def expected1( R.output(lv3) return lv3 - tvm.ir.assert_structural_equal(mod["main"], expected1) + tvm.ir.assert_structural_equal(mod["main"], expected1.with_attr("global_symbol", "main")) # Test a batched LHS case, slicing is done on the axis 2 mod = get_parallel_matmul(3, lhs_shape=(2, 1024, 640)) @@ -121,7 +121,7 @@ def expected2( R.output(lv3) return lv3 - tvm.ir.assert_structural_equal(mod["main"], expected2) + tvm.ir.assert_structural_equal(mod["main"], expected2.with_attr("global_symbol", "main")) def test_bias(): @@ -151,7 +151,7 @@ def expected1( R.output(lv6) return lv6 - tvm.ir.assert_structural_equal(mod["main"], expected1) + tvm.ir.assert_structural_equal(mod["main"], expected1.with_attr("global_symbol", "main")) mod = get_parallel_matmul(3, with_bias=[True, False, True]) mod = CombineParallelMatmul()(mod) @@ -178,7 +178,7 @@ def expected2( R.output(lv5) return lv5 - tvm.ir.assert_structural_equal(mod["main"], expected2) + tvm.ir.assert_structural_equal(mod["main"], expected2.with_attr("global_symbol", "main")) def test_activation(): @@ -204,7 +204,7 @@ def expected1( R.output(lv6) return lv6 - tvm.ir.assert_structural_equal(mod["main"], expected1) + tvm.ir.assert_structural_equal(mod["main"], expected1.with_attr("global_symbol", "main")) mod = get_parallel_matmul(3, activation=["gelu", "relu", "relu"]) mod = CombineParallelMatmul()(mod) @@ -230,7 +230,7 @@ def expected2( 
R.output(lv6) return lv6 - tvm.ir.assert_structural_equal(mod["main"], expected2) + tvm.ir.assert_structural_equal(mod["main"], expected2.with_attr("global_symbol", "main")) mod = get_parallel_matmul(3, activation=["relu", None, None]) mod = CombineParallelMatmul()(mod) @@ -255,7 +255,7 @@ def expected3( R.output(lv4) return lv4 - tvm.ir.assert_structural_equal(mod["main"], expected3) + tvm.ir.assert_structural_equal(mod["main"], expected3.with_attr("global_symbol", "main")) def test_bias_activation(): @@ -286,7 +286,7 @@ def expected1( R.output(lv9) return lv9 - tvm.ir.assert_structural_equal(mod["main"], expected1) + tvm.ir.assert_structural_equal(mod["main"], expected1.with_attr("global_symbol", "main")) mod = get_parallel_matmul(3, with_bias=[True, True, True], activation=["relu", None, "relu"]) mod = CombineParallelMatmul()(mod) @@ -316,7 +316,7 @@ def expected2( R.output(lv8) return lv8 - tvm.ir.assert_structural_equal(mod["main"], expected2) + tvm.ir.assert_structural_equal(mod["main"], expected2.with_attr("global_symbol", "main")) mod = get_parallel_matmul(3, with_bias=[True, False, True], activation=["relu", None, "relu"]) mod = CombineParallelMatmul()(mod) @@ -345,7 +345,7 @@ def expected3( R.output(lv7) return lv7 - tvm.ir.assert_structural_equal(mod["main"], expected3) + tvm.ir.assert_structural_equal(mod["main"], expected3.with_attr("global_symbol", "main")) def test_rhs_batched(): diff --git a/tests/python/relax/test_transform_normalize.py b/tests/python/relax/test_transform_normalize.py index 874e83c7f955..a6feb0b8abca 100644 --- a/tests/python/relax/test_transform_normalize.py +++ b/tests/python/relax/test_transform_normalize.py @@ -44,7 +44,7 @@ def test_normalize_function(): after_mod = relax.transform.Normalize()(before_mod) - @R.function + @R.function(private=True) def expected(x: R.Tensor(("m", "n"), "float16")) -> R.Tensor(dtype="float16", ndim=2): gv = R.add(x, x) gv1 = R.add(x, x) @@ -86,7 +86,7 @@ def test_normalize_if(): before_mod = tvm.IRModule.from_expr(f) after_mod = relax.transform.Normalize()(before_mod) - @R.function + @R.function(private=True) def expected( cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32") ) -> R.Tensor(dtype="float32", ndim=1): @@ -151,7 +151,7 @@ def test_normalize_seq_body(): before_mod = tvm.IRModule.from_expr(f) after_mod = relax.transform.Normalize()(before_mod) - @R.function + @R.function(private=True) def expected( x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32") ) -> R.Tensor(ndim=0, dtype="int32"): @@ -175,7 +175,7 @@ def test_normalize_func_body(): before_mod = tvm.IRModule.from_expr(f) after_mod = relax.transform.Normalize()(before_mod) - @R.function + @R.function(private=True) def expected( x: R.Tensor((), dtype="int32"), y: R.Tensor((), dtype="int32") ) -> R.Tensor(ndim=0, dtype="int32"): @@ -207,7 +207,7 @@ def test_normalize_if_branches(): before_mod = tvm.IRModule.from_expr(f) after_mod = relax.transform.Normalize()(before_mod) - @R.function + @R.function(private=True) def expected( cond: R.Tensor((), dtype="bool"), x: R.Tensor((), dtype="int32"), @@ -257,7 +257,7 @@ def test_normalize_if_condition(): before_mod = tvm.IRModule.from_expr(f) after_mod = relax.transform.Normalize()(before_mod) - @R.function + @R.function(private=True) def expected( cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32") ) -> R.Tensor(dtype="float32", ndim=1): @@ -341,7 +341,7 @@ def test_normalize_combine_nearby_blocks(): after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f)) - @R.function + 
@R.function(private=True) def expected(x: R.Tensor((), "int32")): with R.dataflow(): v0 = x @@ -383,7 +383,7 @@ def test_normalize_nested_seq(): ) after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f)) - @R.function + @R.function(private=True) def expected(): x = relax.const(1) z = relax.const(2) @@ -434,7 +434,7 @@ def test_normalize_nested_seq_dataflow(): ) after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f)) - @R.function + @R.function(private=True) def expected(): x = relax.const(1) q = relax.const(2) @@ -507,7 +507,7 @@ def test_normalize_deeply_nested_seq(): ) after_mod = relax.transform.Normalize()(tvm.IRModule.from_expr(f)) - @R.function + @R.function(private=True) def expected(): x = relax.const(1) u = relax.const(2) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 4071d9429525..9305cdbcb129 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -1503,7 +1503,7 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): x = relax.Var("x", R.Tensor((), "int32")) bb = relax.BlockBuilder() - with bb.function("main", (x)): + with bb.function("main", (x), private=True): y = bb.emit(R.add(x, x)) bb.emit_func_output(y) diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index 5b9e232cf3cd..9a8572aadc61 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -41,16 +41,15 @@ def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore # from tvm.script import relax as R @R.function -def main(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): +def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): + R.func_attr({"global_symbol": "func"}) return a""", ) def test_extern_func(): - # note: this function will be treated as private unless a global symbol is added @R.function def relax_func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore - R.func_attr({"global_symbol": "func"}) return a obj = IRModule( @@ -606,12 +605,11 @@ def test_directly_construct_private_funcs(): # public @R.function def func1(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): - R.func_attr({"global_symbol": "foo"}) y: R.Tensor((), dtype="int32") = R.add(x, x) return y # private - @R.function + @R.function(private=True) def func2(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): y: R.Tensor((), dtype="int32") = R.multiply(x, x) return y @@ -619,13 +617,13 @@ def func2(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): # public but there's another attribute @R.function def func3(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): - R.func_attr({"global_symbol": "baz", "relax.force_pure": True}) + R.func_attr({"relax.force_pure": True}) y: R.Tuple = R.print(format="Hi there!") z: R.Tensor((), dtype="int32") = R.add(x, x) return z # private with an attribute - @R.function + @R.function(private=True) def func4(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): R.func_attr({"relax.force_pure": True}) y: R.Tuple = R.print(format="Lol") From d1f7c01bd84123e2442e1667f052fa50d10434b5 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 23 Jun 2023 13:45:54 -0400 Subject: [PATCH 11/26] Formatting fix --- python/tvm/script/parser/core/parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index 0e9068f74cd5..69e262b1d327 100644 --- 
a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -240,7 +240,7 @@ class Parser(doc.NodeVisitor): dispatch_tokens: List[str] function_annotations: Optional[Dict[str, Dict[str, Any]]] var_table: VarTable - inside_function: bool # whether we are within a function + inside_function: bool # whether we are within a function def __init__( self, From 5b6c8144a8f5d2c73f90d52e329293e166ad920c Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 23 Jun 2023 14:01:18 -0400 Subject: [PATCH 12/26] Print privacy attribute for functions not in an IRModule --- src/script/printer/ir_docsifier.cc | 2 +- src/script/printer/relax/function.cc | 22 ++++++- src/script/printer/utils.h | 14 ++++- .../relax/test_tvmscript_printer_relax.py | 60 ++++++++++++++++++- 4 files changed, 90 insertions(+), 8 deletions(-) diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 36ccf91d329e..3ae5e9158276 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -29,7 +29,7 @@ namespace printer { IdDoc IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, const String& name_hint) { ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj; - String name = GenerateUniqueName(name_hint, this->defined_names); + String name = GenerateUniqueName(name_hint, this->defined_names, this->cfg->binding_names); this->defined_names.insert(name); DocCreator doc_factory = [name]() { return IdDoc(name); }; obj2info.insert({obj, VariableInfo{std::move(doc_factory), name}}); diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc index a42685dcf828..fe67f7db4792 100644 --- a/src/script/printer/relax/function.cc +++ b/src/script/printer/relax/function.cc @@ -22,6 +22,21 @@ namespace tvm { namespace script { namespace printer { +bool AtTopLevelFunction(const IRDocsifier& d) { + // fewer than 2 frames: not in a function at all + if (d->frames.size() < 2) { + return false; + } + // if the first frame is a RelaxFrame, then this is not inside a module. 
+ // 2 frames => we are at a function (more than 2 => nested function) + if (d->frames[0]->IsInstance()) { + return d->frames.size() == 2; + } + // otherwise the first two frames pertain to an IR module, + // so 3 frames => we are at a top-level function (more than 3 => nested function) + return d->frames.size() == 3; +} + TVM_REGISTER_NODE_TYPE(RelaxFrameNode); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -54,7 +69,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (n->attrs.defined() && !n->attrs->dict.empty()) { // if the function is a global function and has a global symbol, // then don't print the global symbol (it will be implicit from not being private) - if (d->frames.size() == 3 && n->attrs->dict.count("global_symbol")) { + if (AtTopLevelFunction(d) && n->attrs->dict.count("global_symbol")) { Map new_attrs; for (auto kv : n->attrs->dict) { if (kv.first != "global_symbol") { @@ -81,8 +96,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) dec_keys.push_back("pure"); dec_values.push_back(LiteralDoc::Boolean(false, Optional())); } - // if the function is global and does not have a global symbol, indicate that it's private - if (d->frames.size() == 3 && + // if the function is global or is not in a module and does not have a global symbol, + // indicate that it's private + if (AtTopLevelFunction(d) && (!n->attrs.defined() || !n->attrs->dict.count("global_symbol"))) { dec_keys.push_back("private"); dec_values.push_back(LiteralDoc::Boolean(true, Optional())); diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index 1a5438577676..95edc6008564 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -161,13 +161,25 @@ inline Optional FindFunctionName(const IRDocsifier& d, const BaseFunc& f } inline String GenerateUniqueName(std::string name_hint, - const std::unordered_set& defined_names) { + const std::unordered_set& defined_names, + const Array& binding_names) { for (char& c : name_hint) { if (c != '_' && !std::isalnum(c)) { c = '_'; } } std::string name = name_hint; + // if the name matches the name currently being bound, then do not add a suffix + // (this comes up in the case of defining a local function: the local function + // is the RHS of a binding. 
The var name on the LHS will have been looked at first; + // without this check, we would print a suffix on the function name even though it + // is actually the first definition) + for (const auto& bound_name : binding_names) { + if (name == bound_name) { + return name; + } + } + for (int i = 1; defined_names.count(name) > 0; ++i) { name = name_hint + "_" + std::to_string(i); } diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index 9a8572aadc61..d9996792a926 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -33,6 +33,7 @@ def _assert_print(obj, expected): def test_function(): @R.function def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore + R.func_attr({"some_attr": 1}) return a _assert_print( @@ -42,19 +43,38 @@ def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore @R.function def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): - R.func_attr({"global_symbol": "func"}) + R.func_attr({"some_attr": 1}) + return a""", + ) + + +def test_lone_private_function(): + @R.function(private=True) + def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore + R.func_attr({"some_attr": 1}) + return a + + # name prints as main because without a global symbol, the printer cannot assume a name + _assert_print( + func, + """ +# from tvm.script import relax as R + +@R.function(private=True) +def main(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): + R.func_attr({"some_attr": 1}) return a""", ) def test_extern_func(): @R.function - def relax_func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore + def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): # type: ignore return a obj = IRModule( { - "func": relax_func, + "func": func, "my_ext": relax.ExternFunc("my_ext"), } ) @@ -74,6 +94,40 @@ def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)): ) +def test_nested_function(): + @I.ir_module + class NestedFunction: + @R.function + def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + @R.function + def nested(y: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + return y + + z = nested(x) + return z + + _assert_print( + NestedFunction, + """ +# from tvm.script import ir as I +# from tvm.script import relax as R + +@I.ir_module +class Module: + @R.function + def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + # from tvm.script import relax as R + + @R.function + def nested(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): + return y + + z: R.Tensor((), dtype="int32") = nested(x) + return z +""", + ) + + def test_object_struct_info(): obj = relax.ObjectStructInfo() _assert_print( From 954dfd3d42cf863764fc8e46d205c4e52d92714c Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 23 Jun 2023 16:11:42 -0400 Subject: [PATCH 13/26] Fix placement of pylint overrides in ir.py --- python/tvm/script/ir_builder/relax/ir.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index bb3fda73c1a6..c509a3c860a6 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -180,10 +180,10 @@ def function(is_pure: bool = True, is_private: bool = False) -> frame.FunctionFr frame: FunctionFrame The constructed function frame. 
""" - return _ffi_api.Function( + return _ffi_api.Function( # pylint: disable=no-member is_pure, is_private, - ) # type: ignore[attr-defined] # pylint: disable=no-member + ) # type: ignore[attr-defined] def arg(name: py_str, struct_info: StructInfo) -> Var: From 249ec0eb0a6a1be3d381f8dc44fefef5c197287f Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 23 Jun 2023 16:13:23 -0400 Subject: [PATCH 14/26] Fix line length --- python/tvm/script/parser/relax/parser.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index b8b7e28a1104..863c249975a7 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -182,7 +182,8 @@ def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: self.var_table.add(node.name, local_func_var) purity = find_decorator_annotation(node, "pure") - # treat the function as private if we are inside another function or if it has a privacy annotation + # treat the function as private if we are inside another function + # or if it has a privacy annotation privacy = is_inner_function or find_decorator_annotation(node, "private", default=False) with self.var_table.with_frame(): From 886a63eeb9db230d32c23f32b6705f2247973ed7 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 23 Jun 2023 16:16:03 -0400 Subject: [PATCH 15/26] Check for nested function case in function.cc instead --- src/script/printer/ir_docsifier.cc | 2 +- src/script/printer/relax/function.cc | 10 +++++++++- src/script/printer/utils.h | 14 +------------- 3 files changed, 11 insertions(+), 15 deletions(-) diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 3ae5e9158276..36ccf91d329e 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -29,7 +29,7 @@ namespace printer { IdDoc IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, const String& name_hint) { ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj; - String name = GenerateUniqueName(name_hint, this->defined_names, this->cfg->binding_names); + String name = GenerateUniqueName(name_hint, this->defined_names); this->defined_names.insert(name); DocCreator doc_factory = [name]() { return IdDoc(name); }; obj2info.insert({obj, VariableInfo{std::move(doc_factory), name}}); diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc index fe67f7db4792..50eafa1c2a98 100644 --- a/src/script/printer/relax/function.cc +++ b/src/script/printer/relax/function.cc @@ -43,7 +43,15 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](relax::Function n, ObjectPath n_p, IRDocsifier d) -> Doc { std::unordered_set func_vars; With f(d); - IdDoc func_name = d->Define(n, f(), FindFunctionName(d, n).value_or("main")); + + IdDoc func_name(""); + // if we are binding a local definition, then calling d->Define + // will result in a repeated definition and an incorrect displayed name + if (Optional name = GetBindingName(d)) { + func_name = std::move(IdDoc(name.value())); + } else { + func_name = std::move(d->Define(n, f(), FindFunctionName(d, n).value_or("main"))); + } (*f)->AddDispatchToken(d, "relax"); (*f)->is_func = true; (*f)->func_vars = &func_vars; diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index 95edc6008564..1a5438577676 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -161,25 +161,13 @@ inline 
Optional FindFunctionName(const IRDocsifier& d, const BaseFunc& f } inline String GenerateUniqueName(std::string name_hint, - const std::unordered_set& defined_names, - const Array& binding_names) { + const std::unordered_set& defined_names) { for (char& c : name_hint) { if (c != '_' && !std::isalnum(c)) { c = '_'; } } std::string name = name_hint; - // if the name matches the name currently being bound, then do not add a suffix - // (this comes up in the case of defining a local function: the local function - // is the RHS of a binding. The var name on the LHS will have been looked at first; - // without this check, we would print a suffix on the function name even though it - // is actually the first definition) - for (const auto& bound_name : binding_names) { - if (name == bound_name) { - return name; - } - } - for (int i = 1; defined_names.count(name) > 0; ++i) { name = name_hint + "_" + std::to_string(i); } From 4b4b52fc79aaade7728b1c3f24221f4aacccaf60 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 23 Jun 2023 16:45:40 -0400 Subject: [PATCH 16/26] Print the global symbol if it doesn't match the name for some reason --- src/script/printer/relax/function.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc index 50eafa1c2a98..0892fbeee3b8 100644 --- a/src/script/printer/relax/function.cc +++ b/src/script/printer/relax/function.cc @@ -75,9 +75,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) (*f)->func_vars = nullptr; // Step 4. Print attributes if (n->attrs.defined() && !n->attrs->dict.empty()) { - // if the function is a global function and has a global symbol, - // then don't print the global symbol (it will be implicit from not being private) - if (AtTopLevelFunction(d) && n->attrs->dict.count("global_symbol")) { + // If the function is a global function and has a global symbol, + // then don't print the global symbol (it will be implicit from not being private). + // For a function without an IR module whose global symbol + // doesn't match the function name, we should still print the global symbol attribute. + if (AtTopLevelFunction(d) && n->attrs->dict.count(tvm::attr::kGlobalSymbol) && + n->attrs->dict.at(tvm::attr::kGlobalSymbol) == func_name->name) { Map new_attrs; for (auto kv : n->attrs->dict) { if (kv.first != "global_symbol") { From 8064cfe26e585a93de44eeb674a3dbb3b64812c7 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 23 Jun 2023 16:54:42 -0400 Subject: [PATCH 17/26] Correctly coerce the attribute type --- src/script/printer/relax/function.cc | 4 ++-- .../python/relax/test_tvmscript_printer_relax.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc index 0892fbeee3b8..86c5b62ab72f 100644 --- a/src/script/printer/relax/function.cc +++ b/src/script/printer/relax/function.cc @@ -80,10 +80,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // For a function without an IR module whose global symbol // doesn't match the function name, we should still print the global symbol attribute. 
if (AtTopLevelFunction(d) && n->attrs->dict.count(tvm::attr::kGlobalSymbol) && - n->attrs->dict.at(tvm::attr::kGlobalSymbol) == func_name->name) { + Downcast(n->attrs->dict.at(tvm::attr::kGlobalSymbol)) == func_name->name) { Map new_attrs; for (auto kv : n->attrs->dict) { - if (kv.first != "global_symbol") { + if (kv.first != tvm::attr::kGlobalSymbol) { new_attrs.Set(kv.first, kv.second); } } diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index d9996792a926..a4b2e9cd6c18 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -658,19 +658,19 @@ def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): def test_directly_construct_private_funcs(): # public @R.function - def func1(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + def foo(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): y: R.Tensor((), dtype="int32") = R.add(x, x) return y # private @R.function(private=True) - def func2(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + def bar(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): y: R.Tensor((), dtype="int32") = R.multiply(x, x) return y # public but there's another attribute @R.function - def func3(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + def baz(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): R.func_attr({"relax.force_pure": True}) y: R.Tuple = R.print(format="Hi there!") z: R.Tensor((), dtype="int32") = R.add(x, x) @@ -678,7 +678,7 @@ def func3(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): # private with an attribute @R.function(private=True) - def func4(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + def quux(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): R.func_attr({"relax.force_pure": True}) y: R.Tuple = R.print(format="Lol") z: R.Tensor((), dtype="int32") = R.multiply(x, x) @@ -686,10 +686,10 @@ def func4(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): obj = IRModule( { - "foo": func1, - "bar": func2, - "baz": func3, - "quux": func4, + "foo": foo, + "bar": bar, + "baz": baz, + "quux": quux, } ) _assert_print( From 25e4712b0372a07f1c549328ebe76a8b4a28c085 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Fri, 23 Jun 2023 19:12:58 -0400 Subject: [PATCH 18/26] Tweak pylint override again --- python/tvm/script/ir_builder/relax/ir.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index c509a3c860a6..b06d9547acdb 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -180,10 +180,9 @@ def function(is_pure: bool = True, is_private: bool = False) -> frame.FunctionFr frame: FunctionFrame The constructed function frame. 
""" - return _ffi_api.Function( # pylint: disable=no-member - is_pure, - is_private, - ) # type: ignore[attr-defined] + return _ffi_api.Function( # type: ignore[attr-defined] # pylint: disable=no-member + is_pure, is_private + ) def arg(name: py_str, struct_info: StructInfo) -> Var: From 8c7bab1810a5db371c89000f5d390f90556f8115 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Sun, 25 Jun 2023 11:34:05 -0400 Subject: [PATCH 19/26] Add pylint override for trailing whitespace in printer tests --- tests/python/relax/test_tvmscript_printer_relax.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index a4b2e9cd6c18..a61a80a8890e 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -106,6 +106,9 @@ def nested(y: R.Tensor((), "int32")) -> R.Tensor((), "int32"): z = nested(x) return z + # The pretty-printer inserts the trailing whitespace itself; + # removing it would cause the test to fail + # pylint: disable=trailing-whitespace _assert_print( NestedFunction, """ @@ -126,6 +129,8 @@ def nested(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): return z """, ) + # pylint: enable=trailing-whitespace + def test_object_struct_info(): From b815449fdcfd5989b84aedfa9c0f14623a44d2b2 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 26 Jun 2023 11:57:20 -0400 Subject: [PATCH 20/26] Fix whitespace --- tests/python/relax/test_tvmscript_printer_relax.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index a61a80a8890e..dcaa650149db 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -132,7 +132,6 @@ def nested(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): # pylint: enable=trailing-whitespace - def test_object_struct_info(): obj = relax.ObjectStructInfo() _assert_print( From 68142e12937c471e6bc235d38453a59165f367fe Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 26 Jun 2023 13:52:56 -0400 Subject: [PATCH 21/26] Remove trailing whitespace altogether instead of trying to override it --- tests/python/relax/test_tvmscript_printer_relax.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index dcaa650149db..c37694317361 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -27,7 +27,9 @@ def _assert_print(obj, expected): if not isinstance(obj, str): obj = obj.script(verbose_expr=True) obj = obj.strip() - assert obj == expected.strip(), "\n" + obj + # compare line by line in case there is trailing whitespace in the _middle_ + for obj_line, expected_line in zip(obj.splitlines(), expected.strip().splitlines()): + assert obj_line.strip() == expected_line.strip(), "\n" + obj def test_function(): @@ -106,9 +108,6 @@ def nested(y: R.Tensor((), "int32")) -> R.Tensor((), "int32"): z = nested(x) return z - # The pretty-printer inserts the trailing whitespace itself; - # removing it would cause the test to fail - # pylint: disable=trailing-whitespace _assert_print( NestedFunction, """ @@ -120,7 +119,7 @@ class Module: @R.function def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): # from 
tvm.script import relax as R - + @R.function def nested(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): return y @@ -129,7 +128,6 @@ def nested(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): return z """, ) - # pylint: enable=trailing-whitespace def test_object_struct_info(): From 4e272aea4887f243ce5100ec4043aa813b320983 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 26 Jun 2023 18:02:49 -0400 Subject: [PATCH 22/26] Fix test_utils --- tests/python/relax/test_utils.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py index c55876a3ba2d..f0c4ae0bd2a3 100644 --- a/tests/python/relax/test_utils.py +++ b/tests/python/relax/test_utils.py @@ -71,7 +71,9 @@ def func_copied(x: R.Tensor((3,), "float32"), y: R.Tensor((3,), "float32")): gv = R.add(x, y) return gv - Actual["func_copied"] = relax.utils.copy_with_new_vars(Actual["func"]) + Actual["func_copied"] = relax.utils.copy_with_new_vars(Actual["func"]).with_attr( + "global_symbol", "func_copied" + ) # Assertion will fail if the f_copied contains the same VarNode that's used in # the original function, due to var mapping during structural equal. @@ -113,7 +115,9 @@ def inner(x: R.Tensor((3,), "float32")) -> R.Tensor((3,), dtype="float32"): gv = R.add(x, y) return gv - Actual["func_copied"] = relax.utils.copy_with_new_vars(Actual["func"]) + Actual["func_copied"] = relax.utils.copy_with_new_vars(Actual["func"]).with_attr( + "global_symbol", "func_copied" + ) assert_structural_equal(Actual, Expected) From 5314469e657a42b18a0da543492102155fdd7df6 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 26 Jun 2023 18:03:56 -0400 Subject: [PATCH 23/26] Fix test_tvmscript_parser_op_datatype --- tests/python/relax/test_tvmscript_parser_op_datatype.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_tvmscript_parser_op_datatype.py b/tests/python/relax/test_tvmscript_parser_op_datatype.py index ec71e868d45b..85c5faa8667b 100644 --- a/tests/python/relax/test_tvmscript_parser_op_datatype.py +++ b/tests/python/relax/test_tvmscript_parser_op_datatype.py @@ -47,7 +47,7 @@ def expected(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float16" gv = bb.emit(relax.op.astype(x, "float16")) bb.emit_func_output(gv) - _check(expected, bb.get()["main"]) + _check(expected.with_attr("global_symbol", "main"), bb.get()["main"]) if __name__ == "__main__": From a0df076ba9d60e8914b449cb1807d699cfb58fbf Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 26 Jun 2023 18:05:54 -0400 Subject: [PATCH 24/26] Fix global symbols in torch dynamo importer --- python/tvm/relax/frontend/torch/dynamo.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/frontend/torch/dynamo.py b/python/tvm/relax/frontend/torch/dynamo.py index 3015f77428fb..abdf7b8862fe 100644 --- a/python/tvm/relax/frontend/torch/dynamo.py +++ b/python/tvm/relax/frontend/torch/dynamo.py @@ -154,7 +154,8 @@ def _capture(graph_module: fx.GraphModule, example_inputs): keep_params_as_input=keep_params_as_input, unwrap_unit_return_tuple=True, ) - mod[f"subgraph_{len(mod.get_global_vars())}"] = mod_["main"] + new_name = f"subgraph_{len(mod.get_global_vars())}" + mod[new_name] = mod_["main"].with_attr("global_symbol", new_name) return graph_module.forward dynamo.reset() From d2a49029162ec1f14ec3d94cc4e029e7ad28e652 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Mon, 26 Jun 2023 18:18:03 
-0400
Subject: [PATCH 25/26] Fix test_dataflow_pattern

---
 tests/python/relax/test_dataflow_pattern.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index cbb19c674335..ea83807bf8cd 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -875,7 +875,7 @@ def rewriter(_, matchings):
         return R.multiply(matchings[x], R.const(2, "float32"))
 
     rewritten = rewrite_call(pattern, rewriter, main)
-    tvm.ir.assert_structural_equal(rewritten, expected1)
+    tvm.ir.assert_structural_equal(rewritten, expected1.with_attr("global_symbol", "main"))
 
     add1 = is_op("relax.add")(x, x)
     pattern = is_op("relax.add")(add1, add1)
@@ -884,7 +884,7 @@ def rewriter(_, matchings):
         return R.multiply(matchings[x], R.const(4, "float32"))
 
     rewritten = rewrite_call(pattern, rewriter, main)
-    tvm.ir.assert_structural_equal(rewritten, expected2)
+    tvm.ir.assert_structural_equal(rewritten, expected2.with_attr("global_symbol", "main"))
 
     # No rewriting, return the original call node as is
     def rewriter(orig, _):
@@ -959,7 +959,7 @@ def rewriter(_, matchings):
         return R.nn.attention(matchings[Q], matchings[K], matchings[V])
 
     rewritten = rewrite_call(pattern, rewriter, main)
-    tvm.ir.assert_structural_equal(rewritten, expected)
+    tvm.ir.assert_structural_equal(rewritten, expected.with_attr("global_symbol", "main"))
 
 
 def test_attention_qkv():
@@ -1115,7 +1115,7 @@ def expected(
         inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, matmul3
     )
     rewritten = rewrite_bindings(ctx, rewriter, qkv_x2)
-    tvm.ir.assert_structural_equal(rewritten, expected)
+    tvm.ir.assert_structural_equal(rewritten, expected.with_attr("global_symbol", "qkv_x2"))
 
 
 def test_combine_matmul_emit_order():
@@ -1173,7 +1173,7 @@ def expected(
         inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2, matmul3
     )
     rewritten = rewrite_bindings(ctx, rewriter, main)
-    tvm.ir.assert_structural_equal(rewritten, expected)
+    tvm.ir.assert_structural_equal(rewritten, expected.with_attr("global_symbol", "main"))
 
     # make sure it builds
     mod = tvm.IRModule()
@@ -1272,7 +1272,7 @@ def rewriter(matchings, _):
 
     rewritten = rewrite_bindings(ctx, rewriter, main)
     print(rewritten.script())
-    tvm.ir.assert_structural_equal(rewritten, expected)
+    tvm.ir.assert_structural_equal(rewritten, expected.with_attr("global_symbol", "main"))
 
     # make sure it builds
     mod = tvm.IRModule()

From 5ac02e8889229cce42b35949d711eb4c434ce6c3 Mon Sep 17 00:00:00 2001
From: Steven Lyubomirsky
Date: Mon, 26 Jun 2023 19:52:46 -0400
Subject: [PATCH 26/26] Use tvm::attr::kGlobalSymbol instead of string literal

---
 src/script/printer/relax/function.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/script/printer/relax/function.cc b/src/script/printer/relax/function.cc
index 86c5b62ab72f..bc5f12309f47 100644
--- a/src/script/printer/relax/function.cc
+++ b/src/script/printer/relax/function.cc
@@ -110,7 +110,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
   // if the function is global or is not in a module and does not have a global symbol,
   // indicate that it's private
   if (AtTopLevelFunction(d) &&
-      (!n->attrs.defined() || !n->attrs->dict.count("global_symbol"))) {
+      (!n->attrs.defined() || !n->attrs->dict.count(tvm::attr::kGlobalSymbol))) {
     dec_keys.push_back("private");
     dec_values.push_back(LiteralDoc::Boolean(true, Optional<ObjectPath>()));
   }
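
A minimal usage sketch of the behavior these patches converge on. The module and function names below are invented purely for illustration, and the attribute access assumes the usual `mod[name].attrs["global_symbol"]` accessor; the `private=` decorator argument and the global-symbol handling are the ones introduced in this series:

from tvm.script import ir as I, relax as R


@I.ir_module
class ExampleModule:  # hypothetical module, for illustration only
    @R.function
    def public_func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
        # not marked private, so parsing attaches global_symbol = "public_func"
        y: R.Tensor((), dtype="int32") = R.add(x, x)
        return y

    @R.function(private=True)
    def private_func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
        # marked private, so no global_symbol attribute is attached
        y: R.Tensor((), dtype="int32") = R.multiply(x, x)
        return y


# The public function carries its own name as the global symbol.
assert ExampleModule["public_func"].attrs["global_symbol"] == "public_func"
# Round-tripping through the printer keeps the distinction: private_func is
# printed back with @R.function(private=True) and no global_symbol attribute.
print(ExampleModule.script())

The printer relies on the same convention in reverse: a top-level function without a `global_symbol` attribute is emitted with `private=True`, and the attribute is only spelled out explicitly when it does not match the function's name.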