From ace3b16b823adf8fb575901f94bc5bd159b45e97 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 19 Aug 2024 15:17:24 -0500 Subject: [PATCH 1/6] [Relax] Require correct input/output shapes `R.call_tir` Prior to this commit, the Relax well-formed checker validated arguments provided to Relax functions, but did not validate arguments provided to `R.call_tir`. As a result, incorrect arguments from Relax to TIR would not be checked until runtime, if at all. This commit updates the well-formed checker to verify that `R.call_tir` has received the correct arguments, and has the correct output shape specified in the `out_sinfo` parameter. --- src/relax/op/op.cc | 224 ++++++++-- .../python/relax/test_analysis_well_formed.py | 388 +++++++++++++++++- 2 files changed, 585 insertions(+), 27 deletions(-) diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 0a840248ffe8..3921f1df594e 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -242,15 +242,157 @@ TVM_REGISTER_GLOBAL("relax.op.call_inplace_packed").set_body_typed(MakeCallInpla // call_tir +/* If possible, infer a legal value of `arg_sinfo` + * + * The `R.call_tir` operator and its variants accept an `arg_sinfo` + * parameter, which specifies the shape of the tensor or tensors + * returned by a PrimFunc. This output shape must be compatible with + * the shape defined by the PrimFunc's signature. + * + * For dynamic shapes, it is not always possible to infer the output + * of a TIR PrimFunc from its inputs. For example, a PrimFunc that + * accepts input buffer `T.Buffer([16], "float32")` and output buffer + * `T.Buffer([M, N], "float32")` infers the values of `M` and `N` from + * the shape of the provided output buffer. + * + * If the arguments provided are not compatible with the PrimFunc's + * signature, an error will be raised. If the arguments are + * compatible with the PrimFunc's signature, but are not sufficient to + * determine the output's StructInfo, then `NullOpt` will be returned. + * + * \param func_sinfo The StructInfo of the TIR callee. + * \param arg_sinfo The StructInfo of the argument tuple. + * \param packed_ints_sinfo The StructInfo of the ShapeTuple argument, + * if present. + * + * \return The `arg_sinfo`, if it can be inferred from the arguments. + * Otherwise, NullOpt. + */ +static Optional InferCallTIROutputStructInfoFromArguments( + StructInfo func_sinfo, StructInfo arg_sinfo, Optional packed_ints_sinfo) { + auto opt_callee_sinfo = func_sinfo.as(); + CHECK(opt_callee_sinfo) << "TypeError: " + << "The first argument to `R.call_tir` must be a function, " + << "but instead received argument of type " << func_sinfo; + auto callee_sinfo = opt_callee_sinfo.value(); + + CHECK(callee_sinfo->params.defined()) + << "ValueError: " + << "The first argument to `R.call_tir` must be a function " + << "with known argument types. " + << "However, the first argument was of type " << callee_sinfo; + auto callee_params = callee_sinfo->params.value(); + + const TupleStructInfoNode* args = arg_sinfo.as(); + CHECK(args) << "TypeError: " + << "The second argument to `R.call_tir` must be a tuple, " + << "but instead received expression of type " << arg_sinfo; + + // R.call_tir expects the PrimFunc to have three groups of arguments. + // + // 1. Input arguments that are explicitly provided as Relax arguments. + // 2. Output tensor arguments. + // 3. Shape arguments, represented as `T.int64` in the PrimFunc, and + // as an optional ShapeExpr argument in the `relax::Call` node. 
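+  //
+  // For example (an illustrative sketch only; `scale`, `m`, and `n` are
+  // hypothetical names), a PrimFunc
+  //
+  //     def scale(A: T.Buffer(16, "float32"),   # group (1)
+  //               B_handle: T.handle,           # group (2)
+  //               m: T.int64, n: T.int64):      # group (3)
+  //         B = T.match_buffer(B_handle, [m, n], "float32")
+  //         ...
+  //
+  // could be called from Relax as
+  //
+  //     R.call_tir(scale, [A], out_sinfo=R.Tensor([2, 8], "float32"),
+  //                tir_vars=R.shape([2, 8]))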
+ // + // In order to determine the return type of `R.call_tir`, we must + // identify the PrimFunc arguments that will be in group (2). + size_t num_input_arguments = args->fields.size(); + size_t num_trailing_int_arguments = 0; + const ShapeStructInfoNode* packed_tuple_sinfo = nullptr; + if (packed_ints_sinfo) { + auto packed_sinfo = packed_ints_sinfo.value(); + packed_tuple_sinfo = packed_sinfo.as(); + CHECK(packed_tuple_sinfo && !packed_tuple_sinfo->IsUnknownNdim()) + << "TypeError: " + << "The third argument to `R.call_tir`, if present, " + << "must be a ShapeTuple with known dimensionality. " + << "However, the argument received was of type " << packed_sinfo; + num_trailing_int_arguments = packed_tuple_sinfo->ndim; + } else { + num_trailing_int_arguments = 0; + } + + CHECK_LE(num_input_arguments + num_trailing_int_arguments, callee_params.size()) + << "ValueError: " + << "R.call_tir attempted to call a function using " << num_input_arguments + << " input arguments and " << num_trailing_int_arguments << " trailing integer arguments. " + << "However, the callee only accepts " << callee_params.size() << " arguments in total."; + + // At this point, the return types are known. However, the shapes + // in `callee_params` may contain dynamic shape parameters that are + // not present in the caller's scope. The `DeriveCallRetStructInfo` + // utility can infer the value of dynamic parameters in + // `FuncStructInfoNode::ret` based on definitions in + // `FuncStructInfoNode::params`, inferring the correct values in the + // caller's scope. + // + // Since the callee of `R.call_tir` is provided with output + // arguments, where `DeriveCallRetStructInfo` requires a callee that + // produces its own outputs, a dummy function signature and + // arguments are used. + + auto dummy_callee_sinfo = [&]() -> FuncStructInfo { + Array dummy_params(callee_params.begin(), + callee_params.begin() + num_input_arguments); + + for (size_t i = callee_params.size() - num_trailing_int_arguments; i < callee_params.size(); + i++) { + dummy_params.push_back(callee_params[i]); + } + + Array dummy_ret(callee_params.begin() + num_input_arguments, + callee_params.end() - num_trailing_int_arguments); + auto dummy_out_sinfo = [&]() -> StructInfo { + if (dummy_ret.size() == 1) { + return dummy_ret[0]; + } else { + return TupleStructInfo(dummy_ret); + } + }(); + + return FuncStructInfo(dummy_params, dummy_out_sinfo); + }(); + + auto dummy_args = [&]() -> Array { + Array dummy_args = args->fields.Map( + [](const StructInfo& sinfo) -> Expr { return Var("dummy_leading_arg", sinfo); }); + + for (size_t i = 0; i < num_trailing_int_arguments; i++) { + ICHECK(packed_tuple_sinfo); + PrimStructInfo dummy_arg_sinfo = [&]() { + if (packed_tuple_sinfo->values) { + return PrimStructInfo(packed_tuple_sinfo->values.value()[i]); + } else { + return PrimStructInfo(DataType::Int(64)); + } + }(); + dummy_args.push_back(Var("dummy_trailing_arg", dummy_arg_sinfo)); + } + + return dummy_args; + }(); + + auto derived_ret_sinfo = DeriveCallRetStructInfo( + dummy_callee_sinfo, Call(Var("dummy_callee", dummy_callee_sinfo), dummy_args), + BlockBuilder::Create(NullOpt)); + + return derived_ret_sinfo; +} + StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) { if (call->sinfo_args.size() != 1) { ctx->ReportFatal(Diagnostic::Error(call) << "sinfo_args should have exactly 1 output struct info."); } CHECK(call->args[0]->IsInstance()) - << "call_tir expects the first argument to be a GlobalVar referring to a TIR PrimFunc. 
" - << "However, gets " << call->args[0]; - return call->sinfo_args[0]; + << "R.call_tir expects the first argument to be a GlobalVar referring to a TIR PrimFunc. " + << "However, the argument " << call->args[0] << " instead has type " + << call->args[0]->GetTypeKey(); + + StructInfo explicit_sinfo = call->sinfo_args[0]; + + return explicit_sinfo; } Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { @@ -264,23 +406,51 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { << "or three arguments [callee, arg_tuple, tir_args], " << "but " << call << " has " << call->args.size() << " arguments."; - Expr arg_expr = call->args[1]; + auto callee = call->args[0]; + CHECK(callee->struct_info_.as()) + << "Operation " << call->op << " expects the first argument to be a TIR callee. " + << "However, the first argument " << callee << " has struct info " << callee->struct_info_; + + Expr arg_tuple = call->args[1]; - CHECK(arg_expr->struct_info_.as()) + CHECK(arg_tuple->struct_info_.as()) << "Operation " << call->op << " expects the second argument to be a tuple of relax Expr. " - << "However, the second argument " << arg_expr << " has struct info " - << arg_expr->struct_info_ << "."; - - if (arg_expr.as()) { - return std::move(call); - } + << "However, the second argument " << arg_tuple << " has struct info " + << arg_tuple->struct_info_ << "."; - CHECK(arg_expr.as()) + CHECK(arg_tuple.as() || arg_tuple.as()) << "Operation " << call->op << " must hold its arguments as an in-line tuple. " - << "However, " << call << " has arguments " << arg_expr + << "However, " << call << " has arguments " << arg_tuple << ", which is neither an in-line tuple, " << "nor a variable binding that may be normalized to an in-line tuple."; + auto packed_int_sinfo = [&]() -> Optional { + if (call->args.size() <= 2) { + return NullOpt; + } + + Expr packed_ints = call->args[2]; + CHECK(packed_ints->struct_info_.as()) + << "Operation " << call->op << " expects the optional third argument, " + << "if present, to be a ShapeTuple. " + << "However, the third argument " << packed_ints << " has struct info " + << packed_ints->struct_info_; + return GetStructInfo(packed_ints); + }(); + + CHECK_EQ(call->sinfo_args.size(), 1) + << "R.call_tir should have exactly one `sinfo_args` parameter, " + << "which defines the output of the PrimFunc."; + StructInfo explicit_sinfo = call->sinfo_args[0]; + if (auto inferred_sinfo = InferCallTIROutputStructInfoFromArguments( + GetStructInfo(callee), GetStructInfo(arg_tuple), packed_int_sinfo)) { + CHECK(IsBaseOf(inferred_sinfo.value(), explicit_sinfo)) + << "TypeError: " + << "The `out_sinfo` argument for R.call_tir must be compatible with the PrimFunc. " + << "However, the PrimFunc's signature implies that the output should be " << inferred_sinfo + << ", but the `out_sinfo` argument was " << explicit_sinfo; + } + auto unwrap_binding = [&ctx](Expr expr) -> Optional { if (auto var = expr.as()) { if (auto bound_value = ctx->LookupBinding(var.value())) { @@ -290,14 +460,20 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { return NullOpt; }; - while (auto unwrapped = unwrap_binding(arg_expr)) { - arg_expr = unwrapped.value(); - } + Tuple new_arg_tuple = [&]() { + // No replacement required. The argument tuple is already + // provided as an in-line tuple. + if (auto opt = arg_tuple.as()) { + return opt.value(); + } + + while (auto unwrapped = unwrap_binding(arg_tuple)) { + arg_tuple = unwrapped.value(); + } - Tuple new_arg_expr = [&]() { // Preferred replacement. 
The argument tuple is provided as a // variable, but we know the value bound to that variable. - if (auto opt = arg_expr.as()) { + if (auto opt = arg_tuple.as()) { return opt.value(); } @@ -306,16 +482,18 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { // example, if a relax function accepted a tuple as an parameter, // then provided that same tuple as an argument to call_tir. Array tuple_elements; - size_t num_fields = Downcast(arg_expr->struct_info_)->fields.size(); + size_t num_fields = Downcast(arg_tuple->struct_info_)->fields.size(); for (size_t i = 0; i < num_fields; i++) { - tuple_elements.push_back(TupleGetItem(arg_expr, i)); + tuple_elements.push_back(TupleGetItem(arg_tuple, i)); } return Tuple(tuple_elements); }(); - auto new_args = call->args; - new_args.Set(1, new_arg_expr); - call.CopyOnWrite()->args = new_args; + if (!new_arg_tuple.same_as(arg_tuple)) { + auto new_args = call->args; + new_args.Set(1, new_arg_tuple); + call.CopyOnWrite()->args = new_args; + } return std::move(call); } diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index 7deddfd28eb9..0918b1d70b81 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -14,15 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + import pytest + import tvm import tvm.testing + from tvm import relax as rx from tvm import tir -from tvm.script import relax as R -from tvm.script import ir as I -from tvm.script import tir as T -from tvm.script import ir as I +from tvm.script import ir as I, relax as R, tir as T m = tir.Var("m", "int64") n = tir.Var("n", "int64") @@ -702,5 +702,385 @@ def is_bfloat16_dtype(tensor: T.handle) -> T.bool: assert rx.analysis.well_formed(Module) +def test_call_tir_with_matching_arguments(): + """R.call_tir is well-formed when called with matching arguments""" + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([16], "float16")) -> R.Prim("bool"): + B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16")) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.float16(1.0) + + assert rx.analysis.well_formed(Module) + + +def test_call_tir_input_ndim(): + """Arguments to R.call_tir must have the correct dimensionality + + Here, the `add_one` function expects a 1-d input tensor, but is + called with a 2-d tensor. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([4, 4], "float16")) -> R.Prim("bool"): + B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16")) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.float16(1.0) + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_output_ndim(): + """Output shape R.call_tir must have the correct dimensionality + + Here, the `add_one` function requires a 1-d output tensor, but is + provided with a 2-d tensor. 
+ """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([16], "float16")) -> R.Prim("bool"): + B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([4, 4], "float16")) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.float16(1.0) + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_input_shape(): + """Arguments to R.call_tir must have the correct shape + + Here, the `add_one` function expects an input tensor with 16 + elements, but is called with an input tensor with 32 elements. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([32], "float16")) -> R.Prim("bool"): + B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16")) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.float16(1.0) + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_output_shape(): + """Output shape R.call_tir must have the correct shape + + Here, the `add_one` function requires an output tensor with 16 + elements, but is provided an output tensor with 32 elements. + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([16], "float16")) -> R.Prim("bool"): + B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([32], "float16")) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.float16(1.0) + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_input_dtype(): + """Arguments to R.call_tir must have the correct dtype + + Here, the `add_one` function expects an input tensor containing + float16 value, but is called with an input tensor containing + float32 values. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([16], "float32")) -> R.Prim("bool"): + B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16")) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.float16(1.0) + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_output_dtype(): + """Output shape R.call_tir must have the correct shape + + Here, the `add_one` function requires an output tensor that may be + populated with float16 values, but is provided an output tensor + that may be populated with float32 elements. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([16], "float16")) -> R.Prim("bool"): + B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float32")) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.float16(1.0) + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_with_correct_dynamic_output_shape(): + """Output shape R.call_tir may not be verifiable + + Here, the input arguments to the `reshape` function are not + sufficient to infer the shape of the outputs. 
This is legal, + since the output shape is determined by the `out_sinfo` parameter. + + Inability to verify the output shape does not mean that the output + shape is invalid. + + """ + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([16], "float16")) -> R.Prim("bool"): + B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([2, 8], "float16")) + return B + + @T.prim_func + def reshape(A: T.Buffer(16, "float16"), B_handle: T.handle): + M = T.int64() + N = T.int64() + B = T.match_buffer(B_handle, [M, N], dtype="float16") + + for i, j in T.grid(M, N): + with T.block("compute"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi * N + vj] + + assert rx.analysis.well_formed(Module) + + +@pytest.mark.xfail(reason="Not supported") +def test_call_tir_with_incorrect_dynamic_output_shape(): + """Output shape R.call_tir may not be verifiable + + Here, the input arguments to the `reshape` function are not + sufficient to infer the shape of the outputs. Even though the + IRModule will not provide well-defined output due to the + out-of-bounds read from buffer A, catching this error is beyond + the current scope of the Relax well-formed checker. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([16], "float16")) -> R.Prim("bool"): + B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([16, 16], "float16")) + return B + + @T.prim_func + def reshape(A: T.Buffer(16, "float16"), B_handle: T.handle): + M = T.int64() + N = T.int64() + B = T.match_buffer(B_handle, [M, N], dtype="float16") + + for i, j in T.grid(M, N): + with T.block("compute"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi * N + vj] + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_incorrect_dimensionality_of_output_shape(): + """Dimensionality may be verified + + Here, the input arguments to the `reshape` function are not + sufficient to infer the shape of the outputs. + + Even though the output shape may not be inferred from the input + arguments, the output dimensionality can still be inferred from + the PrimFunc signature. The IRModule below is ill-formed, because + the PrimFunc requires a 2-d output argument, but is provided with + a 3-d output argument. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([16], "float16")) -> R.Prim("bool"): + B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([2, 4, 2], "float16")) + return B + + @T.prim_func + def reshape(A: T.Buffer(16, "float16"), B_handle: T.handle): + M = T.int64() + N = T.int64() + B = T.match_buffer(B_handle, [M, N], dtype="float16") + + for i, j in T.grid(M, N): + with T.block("compute"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi * N + vj] + + assert not rx.analysis.well_formed(Module) + + +@pytest.mark.xfail(reason="Not yet supported") +def test_call_tir_output_shape_with_mixed_static_and_dynamic(): + """Some dimensions of the R.call_tir output shape may be verifiable + + Here, the input arguments to the `reshape` function are not + sufficient to infer the shape of the outputs. This is legal, + since the output shape is taken from the `out_sinfo` parameter. + + Identifying this failure mode is not yet supported in the current + implementation. This is because the output is inferred as + `R.Tensor(ndim=3, dtype="float16")`, and the explicit `out_sinfo` + is a 3-d tensor. The mismatch in the first dimension is not yet + counted, because the entire tensor shape is removed by + `EraseToWellDefined`. 
+ + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([256], "float16")) -> R.Prim("bool"): + B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([8, 16, 2], "float16")) + return B + + @T.prim_func + def reshape(A: T.Buffer(256, "float16"), B_handle: T.handle): + M = T.int64() + N = T.int64() + B = T.match_buffer(B_handle, [16, M, N], dtype="float16") + + for i, j, k in T.grid(16, M, N): + with T.block("compute"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[vi, vj, vk] = A[vi * N * M + vj * N + vk] + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_with_correct_inferred_dynamic_output_shape(): + """Some dynamic output shapes of R.call_tir may be inferred + + Here, the `flatten` function is dynamic, and will flatten any 2-d + TIR buffer. Even though it is dynamic, the input shapes are + sufficient to infer that `M==8` and `N==4`. As a result, the + output shape of `[M*N]` can be inferred to be `[32]`, and the + shape specified in `out_sinfo` can be validated. + + """ + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([8, 4], "float16")) -> R.Prim("bool"): + B = R.call_tir(Module.flatten, A, out_sinfo=R.Tensor([32], "float16")) + return B + + @T.prim_func + def flatten(A_handle: T.handle, B_handle: T.handle): + M = T.int64() + N = T.int64() + A = T.match_buffer(A_handle, [M, N], dtype="float16") + B = T.match_buffer(B_handle, [M * N], dtype="float16") + + for i in T.grid(M * N): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi // N, vi % N] + + assert rx.analysis.well_formed(Module) + + +def test_call_tir_with_incorrect_inferred_dynamic_output_shape(): + """Some dynamic output shapes of R.call_tir may be inferred + + Here, the `flatten` function is dynamic, and will flatten any 2-d + TIR buffer. Even though it is dynamic, the input shapes are + sufficient to infer that `M==8` and `N==4`. As a result, the + output shape of `[M*N]` can be inferred to be `[32]`, and the + shape specified in `out_sinfo` can be validated. + + This unit test is identical to the above test + `test_call_tir_with_correct_inferred_dynamic_output_shape`, except + that the output shape is explicitly specified as `[64]`, which is + caught as a mismatch from the expected output shape. 
+ + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([8, 4], "float16")) -> R.Prim("bool"): + B = R.call_tir(Module.flatten, A, out_sinfo=R.Tensor([64], "float16")) + return B + + @T.prim_func + def flatten(A_handle: T.handle, B_handle: T.handle): + M = T.int64() + N = T.int64() + A = T.match_buffer(A_handle, [M, N], dtype="float16") + B = T.match_buffer(B_handle, [M * N], dtype="float16") + + for i in T.grid(M * N): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi // N, vi % N] + + assert not rx.analysis.well_formed(Module) + + if __name__ == "__main__": tvm.testing.main() From a70cdbe1fe4d515f05b64c46419f8b2f462d2887 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 20 Aug 2024 10:51:35 -0500 Subject: [PATCH 2/6] Update to handle R.dist.call_tir --- src/relax/op/op.cc | 18 +++++ .../python/relax/test_analysis_well_formed.py | 65 +++++++++++++++---- tests/python/relax/test_dataflow_pattern.py | 2 +- 3 files changed, 71 insertions(+), 14 deletions(-) diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 3921f1df594e..77ec650fa299 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -18,6 +18,7 @@ */ #include #include +#include #include #include #include @@ -319,6 +320,23 @@ static Optional InferCallTIROutputStructInfoFromArguments( << " input arguments and " << num_trailing_int_arguments << " trailing integer arguments. " << "However, the callee only accepts " << callee_params.size() << " arguments in total."; + // While Relax can specify a distributed tensor, TIR cannot. The + // current implementation does not support determining the output + // shape for `R.dist.call_tir` calls, as it depends on the lowering + // of DistIR into regular Relax. + std::function contains_dtensor = [&contains_dtensor](StructInfo sinfo) -> bool { + if (sinfo.as()) { + return true; + } else if (auto tuple = sinfo.as()) { + return std::any_of(tuple->fields.begin(), tuple->fields.end(), contains_dtensor); + } else { + return false; + } + }; + if (contains_dtensor(arg_sinfo)) { + return NullOpt; + } + // At this point, the return types are known. However, the shapes // in `callee_params` may contain dynamic shape parameters that are // not present in the caller's scope. 
The `DeriveCallRetStructInfo` diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index 0918b1d70b81..1e9e2a649113 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -708,7 +708,7 @@ def test_call_tir_with_matching_arguments(): @I.ir_module class Module: @R.function - def main(A: R.Tensor([16], "float16")) -> R.Prim("bool"): + def main(A: R.Tensor([16], "float16")): B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16")) return B @@ -733,7 +733,7 @@ def test_call_tir_input_ndim(): @I.ir_module(check_well_formed=False) class Module: @R.function - def main(A: R.Tensor([4, 4], "float16")) -> R.Prim("bool"): + def main(A: R.Tensor([4, 4], "float16")): B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16")) return B @@ -757,7 +757,7 @@ def test_call_tir_output_ndim(): @I.ir_module(check_well_formed=False) class Module: @R.function - def main(A: R.Tensor([16], "float16")) -> R.Prim("bool"): + def main(A: R.Tensor([16], "float16")): B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([4, 4], "float16")) return B @@ -782,7 +782,7 @@ def test_call_tir_input_shape(): @I.ir_module(check_well_formed=False) class Module: @R.function - def main(A: R.Tensor([32], "float16")) -> R.Prim("bool"): + def main(A: R.Tensor([32], "float16")): B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16")) return B @@ -806,7 +806,7 @@ def test_call_tir_output_shape(): @I.ir_module(check_well_formed=False) class Module: @R.function - def main(A: R.Tensor([16], "float16")) -> R.Prim("bool"): + def main(A: R.Tensor([16], "float16")): B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([32], "float16")) return B @@ -832,7 +832,7 @@ def test_call_tir_input_dtype(): @I.ir_module(check_well_formed=False) class Module: @R.function - def main(A: R.Tensor([16], "float32")) -> R.Prim("bool"): + def main(A: R.Tensor([16], "float32")): B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16")) return B @@ -858,7 +858,7 @@ def test_call_tir_output_dtype(): @I.ir_module(check_well_formed=False) class Module: @R.function - def main(A: R.Tensor([16], "float16")) -> R.Prim("bool"): + def main(A: R.Tensor([16], "float16")): B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float32")) return B @@ -887,7 +887,7 @@ def test_call_tir_with_correct_dynamic_output_shape(): @I.ir_module class Module: @R.function - def main(A: R.Tensor([16], "float16")) -> R.Prim("bool"): + def main(A: R.Tensor([16], "float16")): B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([2, 8], "float16")) return B @@ -920,7 +920,7 @@ def test_call_tir_with_incorrect_dynamic_output_shape(): @I.ir_module(check_well_formed=False) class Module: @R.function - def main(A: R.Tensor([16], "float16")) -> R.Prim("bool"): + def main(A: R.Tensor([16], "float16")): B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([16, 16], "float16")) return B @@ -955,7 +955,7 @@ def test_call_tir_incorrect_dimensionality_of_output_shape(): @I.ir_module(check_well_formed=False) class Module: @R.function - def main(A: R.Tensor([16], "float16")) -> R.Prim("bool"): + def main(A: R.Tensor([16], "float16")): B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([2, 4, 2], "float16")) return B @@ -993,7 +993,7 @@ def test_call_tir_output_shape_with_mixed_static_and_dynamic(): @I.ir_module(check_well_formed=False) class Module: @R.function - def main(A: R.Tensor([256], "float16")) -> R.Prim("bool"): + def 
main(A: R.Tensor([256], "float16")): B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([8, 16, 2], "float16")) return B @@ -1025,7 +1025,7 @@ def test_call_tir_with_correct_inferred_dynamic_output_shape(): @I.ir_module class Module: @R.function - def main(A: R.Tensor([8, 4], "float16")) -> R.Prim("bool"): + def main(A: R.Tensor([8, 4], "float16")): B = R.call_tir(Module.flatten, A, out_sinfo=R.Tensor([32], "float16")) return B @@ -1063,7 +1063,7 @@ def test_call_tir_with_incorrect_inferred_dynamic_output_shape(): @I.ir_module(check_well_formed=False) class Module: @R.function - def main(A: R.Tensor([8, 4], "float16")) -> R.Prim("bool"): + def main(A: R.Tensor([8, 4], "float16")): B = R.call_tir(Module.flatten, A, out_sinfo=R.Tensor([64], "float16")) return B @@ -1082,5 +1082,44 @@ def flatten(A_handle: T.handle, B_handle: T.handle): assert not rx.analysis.well_formed(Module) +def test_call_tir_with_dtensor_arguments(): + """R.call_tir and R.dist.call_tir share the same operation + + Both `R.call_tir` and `R.dist.call_tir` produce the same + "relax.call_tir" operation, differing only in the StructInfo of + their arguments. Normalization of "relax.call_tir" must handle + `R.DTensor` arguments. + + """ + + # from tvm.script.parser import relax as R + + @I.ir_module + class Module: + I.module_attrs({"device_num": 4}) + I.module_global_infos({"mesh": [R.dist.device_mesh([4], I.Range(0, 4))]}) + + @R.function + def main(A: R.dist.DTensor([8, 4], "float16", "mesh[0]", "S[0]")): + B = R.dist.call_tir( + Module.flatten, A, out_sinfo=R.dist.DTensor([64], "float16", "mesh[0]", "S[0]") + ) + return B + + @T.prim_func + def flatten(A_handle: T.handle, B_handle: T.handle): + M = T.int64() + N = T.int64() + A = T.match_buffer(A_handle, [M, N], dtype="float16") + B = T.match_buffer(B_handle, [M * N], dtype="float16") + + for i in T.grid(M * N): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi // N, vi % N] + + assert rx.analysis.well_formed(Module) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index 03a3beb2f27e..7a3b65cea10e 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -72,7 +72,7 @@ def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> lv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32")) lv1 = R.call_tir(cls.tir_relu, (lv0), R.Tensor((32, 32), dtype="float32")) lv2 = R.call_tir( - cls.tir_zeros, (lv1), R.Tensor((32,), dtype="float32"), tir_vars=R.ShapeExpr([32]) + cls.tir_zeros, [], R.Tensor((32,), dtype="float32"), tir_vars=R.ShapeExpr([32]) ) gv = (lv1, lv2) R.output(gv) From 006bb1e6ea3a80365c24ed82875b94ae7b501001 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 21 Aug 2024 12:26:52 -0500 Subject: [PATCH 3/6] Update unit tests that were caught by the new well-formed check --- src/relax/op/op.cc | 43 +++++++-- src/relax/transform/fuse_tir.cc | 3 +- ...istributed_transform_propagate_sharding.py | 8 -- .../python/relax/test_analysis_well_formed.py | 87 +++++++++++++++++++ tests/python/relax/test_ast_printer.py | 9 +- tests/python/relax/test_dataflow_inplace.py | 16 ++-- tests/python/relax/test_frontend_dynamo.py | 41 +++++---- tests/python/relax/test_frontend_nn_op.py | 18 +++- tests/python/relax/test_transform.py | 6 +- .../test_transform_dead_code_elimination.py | 60 +++++++------ tests/python/relax/test_transform_fuse_ops.py | 14 +-- 
.../test_transform_fuse_ops_by_pattern.py | 78 ++++++++--------- .../test_transform_lazy_transform_params.py | 26 +++--- ...test_transform_rewrite_dataflow_reshape.py | 37 ++++---- tests/python/relax/test_tvmscript_parser.py | 55 ++++++------ tests/python/relax/test_vm_build.py | 28 +++--- 16 files changed, 337 insertions(+), 192 deletions(-) diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 77ec650fa299..2e42e1a06167 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -265,12 +265,17 @@ TVM_REGISTER_GLOBAL("relax.op.call_inplace_packed").set_body_typed(MakeCallInpla * \param arg_sinfo The StructInfo of the argument tuple. * \param packed_ints_sinfo The StructInfo of the ShapeTuple argument, * if present. + * \param opt_inplace_indices For `R.call_tir_inplace`, an array of + * indices indicating which outputs are constructed from in-place + * mutation of the inputs. See + * `CallTIRInplaceAttrs::inplace_indices` for more details. * * \return The `arg_sinfo`, if it can be inferred from the arguments. * Otherwise, NullOpt. */ static Optional InferCallTIROutputStructInfoFromArguments( - StructInfo func_sinfo, StructInfo arg_sinfo, Optional packed_ints_sinfo) { + StructInfo func_sinfo, StructInfo arg_sinfo, Optional packed_ints_sinfo, + Optional> opt_inplace_indices) { auto opt_callee_sinfo = func_sinfo.as(); CHECK(opt_callee_sinfo) << "TypeError: " << "The first argument to `R.call_tir` must be a function, " @@ -361,6 +366,22 @@ static Optional InferCallTIROutputStructInfoFromArguments( Array dummy_ret(callee_params.begin() + num_input_arguments, callee_params.end() - num_trailing_int_arguments); + + if (opt_inplace_indices) { + // For R.call_tir_inplace, the `inplace_indices` are used to + // indicate which elements of the `out_sinfo` will be generated + // as in-place mutation from an input. For any in-place + // mutation, the parameter's StructInfo must be inserted into + // `out_sinfo`. + auto inplace_indices = opt_inplace_indices.value(); + for (size_t i = 0; i < inplace_indices.size(); i++) { + auto inplace_input_index = inplace_indices[i]->value; + if (inplace_input_index >= 0) { + dummy_ret.insert(dummy_ret.begin() + i, callee_params[inplace_input_index]); + } + } + } + auto dummy_out_sinfo = [&]() -> StructInfo { if (dummy_ret.size() == 1) { return dummy_ret[0]; @@ -456,12 +477,21 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { return GetStructInfo(packed_ints); }(); + auto opt_inplace_indices = [&]() -> Optional> { + if (const auto* attrs = call->attrs.as()) { + return attrs->inplace_indices; + } else { + return NullOpt; + } + }(); + CHECK_EQ(call->sinfo_args.size(), 1) << "R.call_tir should have exactly one `sinfo_args` parameter, " << "which defines the output of the PrimFunc."; StructInfo explicit_sinfo = call->sinfo_args[0]; - if (auto inferred_sinfo = InferCallTIROutputStructInfoFromArguments( - GetStructInfo(callee), GetStructInfo(arg_tuple), packed_int_sinfo)) { + auto inferred_sinfo = InferCallTIROutputStructInfoFromArguments( + GetStructInfo(callee), GetStructInfo(arg_tuple), packed_int_sinfo, opt_inplace_indices); + if (inferred_sinfo.defined()) { CHECK(IsBaseOf(inferred_sinfo.value(), explicit_sinfo)) << "TypeError: " << "The `out_sinfo` argument for R.call_tir must be compatible with the PrimFunc. 
" @@ -485,13 +515,14 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { return opt.value(); } - while (auto unwrapped = unwrap_binding(arg_tuple)) { - arg_tuple = unwrapped.value(); + Expr unwrapped_tuple = arg_tuple; + while (auto unwrapped = unwrap_binding(unwrapped_tuple)) { + unwrapped_tuple = unwrapped.value(); } // Preferred replacement. The argument tuple is provided as a // variable, but we know the value bound to that variable. - if (auto opt = arg_tuple.as()) { + if (auto opt = unwrapped_tuple.as()) { return opt.value(); } diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index b203b322ab96..612e1459c826 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -1088,8 +1088,7 @@ class TIRFuseMutator : public ExprMutator { const auto& [prim_func, indices] = FusedTIRConstructor::GetFusedTIR(mod, old_gvar); GlobalVar new_gvar(old_gvar->name_hint); - UpdateStructInfo(new_gvar, - FuncStructInfo::OpaqueFunc(StructInfoFromType(prim_func->ret_type))); + UpdateStructInfo(new_gvar, GetStructInfo(prim_func)); mod->Remove(old_gvar); updates->Add(new_gvar, prim_func); diff --git a/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py b/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py index e1f45d278d6c..865051b0b4b9 100644 --- a/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py +++ b/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py @@ -512,13 +512,11 @@ def foo( cls.rotary_embedding, (lv9, cos_cached, sin_cached), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"), - tir_vars=R.shape([256]), ) lv17 = R.call_tir( cls.rotary_embedding, (lv12, cos_cached, sin_cached), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"), - tir_vars=R.shape([256]), ) lv18: R.Tensor((256, 32, 128), dtype="float16") = R.reshape( lv17, R.shape([256, 32, 128]) @@ -712,13 +710,11 @@ def foo( cls.rotary_embedding, (lv9, cos_cached, sin_cached), out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), - tir_vars=R.shape([256]), ) lv17 = R.dist.call_tir( cls.rotary_embedding, (lv12, cos_cached, sin_cached), out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), - tir_vars=R.shape([256]), ) lv18: R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]") = R.reshape( lv17, R.shape([256, 32, 128]) @@ -1278,13 +1274,11 @@ def foo( cls.rotary_embedding, (lv9, cos_cached, sin_cached), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"), - tir_vars=R.shape([256]), ) lv17 = R.call_tir( cls.rotary_embedding, (lv12, cos_cached, sin_cached), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"), - tir_vars=R.shape([256]), ) lv18 = R.call_tir( cls.reshape1, (lv17,), out_sinfo=R.Tensor((256, 32, 128), dtype="float16") @@ -1449,13 +1443,11 @@ def foo( LlamaAttentionLayerTIR.get_global_var("rotary_embedding"), (lv9, cos_cached, sin_cached), out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), - tir_vars=R.shape([256]), ) lv17 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("rotary_embedding"), (lv12, cos_cached, sin_cached), out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), - tir_vars=R.shape([256]), ) lv18 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("reshape1"), diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index 1e9e2a649113..c0b962c3f3a0 100644 --- 
a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -1121,5 +1121,92 @@ def flatten(A_handle: T.handle, B_handle: T.handle): assert rx.analysis.well_formed(Module) +def test_call_tir_inplace_with_correct_shapes(): + """R.call_tir_inplace is well-formed when called with matching arguments""" + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([16], "float16")): + B = R.call_tir_inplace( + Module.add_one, + A, + inplace_indices=[0], + out_sinfo=R.Tensor([16], "float16"), + ) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + A[vi] = A[vi] + T.float16(1.0) + + assert rx.analysis.well_formed(Module) + + +def test_call_tir_inplace_with_incorrect_shapes(): + """R.call_tir_inplace is ill-formed when output shape does not match input""" + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([16], "float16")): + B = R.call_tir_inplace( + Module.add_one, + A, + inplace_indices=[0], + out_sinfo=R.Tensor([32], "float16"), + ) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + A[vi] = A[vi] + T.float16(1.0) + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_inplace_with_some_allocated_outputs(): + """R.call_tir_inplace may contain some non-inplace outputs""" + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([16], "float16"), B: R.Tensor([32], "float16")): + out = R.call_tir_inplace( + Module.add_one, + (A, B), + inplace_indices=[-1, 1], + out_sinfo=[ + R.Tensor([16], "float16"), + R.Tensor([32], "float16"), + ], + ) + return out + + @T.prim_func + def add_one( + A: T.Buffer(16, "float16"), + B: T.Buffer(32, "float16"), + C: T.Buffer(16, "float16"), + ): + for i in range(32): + with T.block("inplace_B"): + vi = T.axis.remap("S", [i]) + B[vi] = B[vi] + T.float16(1.0) + + for i in range(16): + with T.block("output_C"): + vi = T.axis.remap("S", [i]) + C[vi] = A[vi] + T.float16(1.0) + + assert rx.analysis.well_formed(Module) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index 64d5c7381171..6005ecb0fa58 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -43,6 +43,7 @@ def normalize(func: rx.Function) -> rx.Function: """ Normalize the expr to fill in the checked_type_ and struct_info fields everywhere """ + # using a default mutator to use the BlockBuilder's normalizer, # which oddly differs from the Normalize pass @rx.expr_functor.mutator @@ -435,9 +436,13 @@ def test_call_tir(): @tvm.script.ir_module class TestCallTIR: @T.prim_func - def addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None: + def addone(A_handle: T.handle, B_handle: T.handle) -> None: + m = T.int64() + n = T.int64() + A = T.match_buffer(A_handle, (m, n), "float32") + B = T.match_buffer(B_handle, (m, n), "float32") T.func_attr(({"global_symbol": "addone"})) - for i, j in T.grid(16, 16): + for i, j in T.grid(m, n): with T.block("addone"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] + T.int32(1) diff --git a/tests/python/relax/test_dataflow_inplace.py b/tests/python/relax/test_dataflow_inplace.py index 8d5eb07c7858..a127b0fa263f 100644 --- a/tests/python/relax/test_dataflow_inplace.py +++ 
b/tests/python/relax/test_dataflow_inplace.py @@ -172,8 +172,8 @@ def tir_id(x: T.handle, y: T.handle) -> None: T.func_attr({"global_symbol": "tir_id"}) m = T.int32() n = T.int32() - A = T.match_buffer(x, (m, n)) - B = T.match_buffer(y, (m, n)) + A = T.match_buffer(x, (m, n), "int32") + B = T.match_buffer(y, (m, n), "int32") for i, j in T.grid(m, n): with T.block("id"): @@ -185,9 +185,9 @@ def tir_id2(x: T.handle, y: T.handle, z: T.handle) -> None: T.func_attr({"global_symbol": "tir_id"}) m = T.int32() n = T.int32() - A = T.match_buffer(x, (m, n)) - B = T.match_buffer(y, (m, n)) - C = T.match_buffer(z, (m, n)) + A = T.match_buffer(x, (m, n), "int32") + B = T.match_buffer(y, (m, n), "int32") + C = T.match_buffer(z, (m, n), "int32") for i, j in T.grid(m, n): with T.block("id"): @@ -323,9 +323,9 @@ def test_inplace_simple_case(): @I.ir_module class InplaceBasic: @R.function - def main( - x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32") - ) -> R.Tensor((2, 3), "int32"): + def main(x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32")) -> R.Tensor( + (2, 3), "int32" + ): with R.dataflow(): z = R.add(x, y) # cannot be done inplace: x and y are live later p = R.add(z, z) # can be done inplace: z is not used later diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py index d83f83f4e188..ed9f628aea26 100644 --- a/tests/python/relax/test_frontend_dynamo.py +++ b/tests/python/relax/test_frontend_dynamo.py @@ -114,9 +114,10 @@ def main( with db: opt_model = torch.compile(model, backend=relax_dynamo()) inp = torch.randn(10, 100) - tvm.testing.assert_allclose( - opt_model(inp).detach().numpy(), model(inp).detach().numpy(), rtol=1e-5, atol=1e-5 - ) + + default_output = model(inp).detach().numpy() + optimized_output = opt_model(inp).detach().numpy() + tvm.testing.assert_allclose(optimized_output, default_output, rtol=1e-5, atol=1e-5) def test_relax_dynamo_dynamic(): @@ -313,9 +314,9 @@ def forward(self, input): @I.ir_module class Expected1: @R.function - def main( - inp_0: R.Tensor((256, 256), dtype="float32") - ) -> R.Tensor((10, 10), dtype="float32"): + def main(inp_0: R.Tensor((256, 256), dtype="float32")) -> R.Tensor( + (10, 10), dtype="float32" + ): with R.dataflow(): lv: R.Tensor((10, 10), dtype="float32") = R.full( R.shape([10, 10]), R.const(1, "float32"), dtype="float32" @@ -344,9 +345,9 @@ def forward(self, input): @I.ir_module class Expected1: @R.function - def main( - inp_0: R.Tensor((256, 256), dtype="float32") - ) -> R.Tensor((10, 10), dtype="float32"): + def main(inp_0: R.Tensor((256, 256), dtype="float32")) -> R.Tensor( + (10, 10), dtype="float32" + ): with R.dataflow(): lv: R.Tensor((10, 10), dtype="float32") = R.full( R.shape([10, 10]), R.const(1, "float32"), dtype="float32" @@ -379,9 +380,9 @@ def forward(self, input): @I.ir_module class ExpectedGeLU: @R.function - def main( - inp_0: R.Tensor((128, 256), dtype="float32") - ) -> R.Tensor((128, 256), dtype="float32"): + def main(inp_0: R.Tensor((128, 256), dtype="float32")) -> R.Tensor( + (128, 256), dtype="float32" + ): with R.dataflow(): lv: R.Tensor((128, 256), dtype="float32") = R.nn.gelu(inp_0) gv: R.Tensor((128, 256), dtype="float32") = lv @@ -391,9 +392,9 @@ def main( @I.ir_module class ExpectedGeLUTanh: @R.function - def main( - inp_0: R.Tensor((128, 256), dtype="float32") - ) -> R.Tensor((128, 256), dtype="float32"): + def main(inp_0: R.Tensor((128, 256), dtype="float32")) -> R.Tensor( + (128, 256), dtype="float32" + ): with R.dataflow(): lv: R.Tensor((128, 256), 
dtype="float32") = R.nn.gelu_tanh(inp_0) gv: R.Tensor((128, 256), dtype="float32") = lv @@ -488,9 +489,9 @@ def main( @I.ir_module class Expected2: @R.function - def main( - inp_0: R.Tensor((1, 77, 1280), dtype="float32") - ) -> R.Tensor((1, 77, 1280), dtype="float32"): + def main(inp_0: R.Tensor((1, 77, 1280), dtype="float32")) -> R.Tensor( + (1, 77, 1280), dtype="float32" + ): with R.dataflow(): lv: R.Tensor((1,), dtype="int64") = R.arange( R.prim_value(0), R.prim_value(1), R.prim_value(1), dtype="int64" @@ -513,9 +514,7 @@ def main( class Select2(Module): def forward(self, input1): - result = input1[ - torch.arange(1), - ] + result = input1[torch.arange(1),] return result verify_dynamo_model( diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 6c3269195498..4990c80e895e 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -570,10 +570,18 @@ def test_tensor_ir_op(): @T.prim_func(private=True) def fused_rope( # pylint: disable=too-many-locals var_qkv: T.handle, - offset: T.int64, var_q: T.handle, var_k: T.handle, var_v: T.handle, + # Scalar arguments must be specified after tensor arguments, + # including the output tensor arguments + # + # TODO(Lunderberg): Update + # `tvm.relax.frontend.nn.op.tensor_ir_op` to use `PrimValue` + # instead of `tir_vars`, so that the order can be consistent + # between the function definition and the arguments in + # `op.tensor_ir_op`. + offset: T.int64, ): batch_size = T.int64() seq_len = T.int64() @@ -601,7 +609,7 @@ def test(self, qkv: Tensor, offset: tir.Var): @I.ir_module class Expected: @T.prim_func(private=True) - def llama_fused_rope(var_qkv: T.handle, offset: T.int64, var_q: T.handle, var_k: T.handle, var_v: T.handle): + def llama_fused_rope(var_qkv: T.handle, var_q: T.handle, var_k: T.handle, var_v: T.handle, offset: T.int64): batch_size, seq_len = T.int64(), T.int64() qkv = T.match_buffer(var_qkv, (batch_size, seq_len, 24, 16), "float16") q = T.match_buffer(var_q, (batch_size, seq_len, 8, 16), "float16") @@ -669,10 +677,11 @@ class Model(Module): def test( self, embedding_table: Tensor, input_ids: Tensor, embedding_dst: Tensor, offset: int ): - tensor_expr_op_out = op.tensor_ir_op( + tensor_expr_op_out = op.tensor_ir_inplace_op( inplace_take, "inplace_take", args=[embedding_table, input_ids, embedding_dst, offset], + inplace_indices=[2], out=Tensor.placeholder(embedding_dst.shape, embedding_dst.dtype), ) return tensor_expr_op_out @@ -719,10 +728,11 @@ def test( R.func_attr({"num_input": 4}) cls = Expected with R.dataflow(): - lv1 = R.call_tir( + lv1 = R.call_tir_inplace( cls.inplace_take, (embedding_table, input_ids, embedding_dst), out_sinfo=R.Tensor((total_seq_len, hidden_size), dtype), + inplace_indices=[2], tir_vars=R.shape([offset_1]), ) gv1: R.Tensor((total_seq_len, hidden_size), dtype) = lv1 diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index ee2df866fb35..e3274aea886a 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -86,7 +86,11 @@ def test_call_tir_rewrite(): @tvm.script.ir_module class TestCallTIRRewrite: @T.prim_func - def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + def exp(A_handle: T.handle, B_handle: T.handle): + m = T.int64() + n = T.int64() + A = T.match_buffer(A_handle, (m, n), "float32") + B = T.match_buffer(B_handle, (m, n), "float32") T.evaluate(0) @R.function diff --git 
a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py index 142faf51607b..5e3d50b9e1f5 100644 --- a/tests/python/relax/test_transform_dead_code_elimination.py +++ b/tests/python/relax/test_transform_dead_code_elimination.py @@ -176,9 +176,9 @@ def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32 return gv0 @R.function - def main( - x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") - ) -> R.Tensor((16, 16), "float32"): + def main(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")) -> R.Tensor( + (16, 16), "float32" + ): gv0 = R.call_tir(InputModule.tir_add, (x, w), R.Tensor((16, 16), dtype="float32")) return gv0 @@ -213,9 +213,9 @@ def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32 return gv0 @R.function - def foo( - x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") - ) -> R.Tensor((16, 16), "float32"): + def foo(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")) -> R.Tensor( + (16, 16), "float32" + ): gv0 = R.call_tir(InputModule.tir_add, (x, w), R.Tensor((16, 16), dtype="float32")) return gv0 @@ -254,9 +254,9 @@ def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32 return gv0 @R.function - def foo( - x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") - ) -> R.Tensor((16, 16), "float32"): + def foo(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")) -> R.Tensor( + (16, 16), "float32" + ): gv0 = R.call_tir(InputModule.tir_add, (x, w), R.Tensor((16, 16), dtype="float32")) return gv0 @@ -277,18 +277,26 @@ def main(x: R.Tensor((16, 16), "float32")) -> R.Tensor((16, 16), "float32"): def test_unused_relax_func_symbolic_shape(): # Test with relax function w/ symbolic shape. 
- @tvm.script.ir_module + @tvm.script.ir_module(check_well_formed=False) class InputModule: @T.prim_func - def tir_add( - x: T.Buffer((16, 16), "float32"), - y: T.Buffer((16, 16), "float32"), - z: T.Buffer((16, 16), "float32"), + def tir_matmul( + x_handle: T.handle, + y_handle: T.handle, + z_handle: T.handle, ) -> None: - for i, j in T.grid(16, 16): - with T.block("add"): - vi, vj = T.axis.remap("SS", [i, j]) - z[vi, vj] = x[vi, vj] + y[vi, vj] + m = T.int64() + n = T.int64() + k = T.int64() + x = T.match_buffer(x_handle, (m, n), "float32") + y = T.match_buffer(y_handle, (n, k), "float32") + z = T.match_buffer(z_handle, (m, k), "float32") + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + z[vi, vj] = 0.0 + z[vi, vj] = z[vi, vj] + x[vi, vk] * y[vk, vj] @R.function(private=True) def unused_func(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")): @@ -298,7 +306,7 @@ def unused_func(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "flo @R.function def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")): m, k = T.int64(), T.int64() - gv0 = R.call_tir(InputModule.tir_add, (x, w), R.Tensor((m + 1, k), dtype="float32")) + gv0 = R.call_tir(InputModule.tir_matmul, (x, w), R.Tensor((m, k), dtype="float32")) return gv0 mod = InputModule @@ -306,7 +314,7 @@ def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")) new_mod = DeadCodeElimination()(mod) assert check_if_func_exists(new_mod, "main") - assert check_if_func_exists(new_mod, "tir_add") + assert check_if_func_exists(new_mod, "tir_matmul") assert not check_if_func_exists(new_mod, "unused_func") @@ -331,9 +339,9 @@ def relax_add(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") return gv0 @R.function - def main( - x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") - ) -> R.Tensor((16, 16), "float32"): + def main(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")) -> R.Tensor( + (16, 16), "float32" + ): gv0 = InputModule.relax_add(x, w) return gv0 @@ -367,9 +375,9 @@ def unused_func2(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float3 return gv0 @R.function - def main( - x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") - ) -> R.Tensor((16, 16), "float32"): + def main(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")) -> R.Tensor( + (16, 16), "float32" + ): gv0 = R.add(x, w) return gv0 diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py index 17bf58613294..f751e1b1a7df 100644 --- a/tests/python/relax/test_transform_fuse_ops.py +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -875,7 +875,7 @@ class Module: def main(x: R.Tensor((1, 512, 64, 64), "float32"), mean: R.Tensor((64, 64), "float32"), var: R.Tensor((64, 64), "float32")): cls = Module with R.dataflow(): - gv0 = R.call_tir(cls.layer_norm, (x, mean, var), out_sinfo=R.Tensor((1, 512, 64, 64))) + gv0 = R.call_tir(cls.layer_norm, (x, mean, var), out_sinfo=R.Tensor((1, 512, 64, 64), 'float32')) gv1 = R.call_tir(cls.relu, gv0, out_sinfo=R.Tensor((1, 512, 64, 64), "float32")) R.output(gv1) return gv1 @@ -955,7 +955,7 @@ def fused_layer_norm_relu(x: R.Tensor((1, 512, 64, 64), dtype="float32"), mean: R.func_attr({"Primitive": 1}) cls = Expected with R.dataflow(): - gv0 = R.call_tir(cls.layer_norm, (x, mean, var), out_sinfo=R.Tensor((1, 512, 64, 64))) + gv0 = R.call_tir(cls.layer_norm, (x, mean, 
var), out_sinfo=R.Tensor((1, 512, 64, 64), 'float32')) gv = R.call_tir(cls.relu, (gv0,), out_sinfo=R.Tensor((1, 512, 64, 64), dtype="float32")) R.output(gv) return gv @@ -1452,7 +1452,7 @@ def main( R.Tensor((2,), "float32"), R.Tensor((2,), "float32"), R.Tensor((2,), "float32"), - ) + ), ): with R.dataflow(): x0 = x[0] @@ -1486,7 +1486,7 @@ def main( R.Tensor((2,), dtype="float32"), R.Tensor((2,), dtype="float32"), R.Tensor((2,), dtype="float32"), - ) + ), ) -> R.Tensor((2,), dtype="float32"): cls = Expected with R.dataflow(): @@ -1633,9 +1633,9 @@ def main( ) -> R.Tensor((10, 20), dtype="float32"): cls = Expected with R.dataflow(): - gv1: R.Tensor( - (10, 20), dtype="float32" - ) = cls.fused_add_exp_inplace_squeeze_inplace(x, p0) + gv1: R.Tensor((10, 20), dtype="float32") = ( + cls.fused_add_exp_inplace_squeeze_inplace(x, p0) + ) R.output(gv1) return gv1 diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py index 1582526042f1..0798ca462bf9 100644 --- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py +++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py @@ -56,9 +56,9 @@ def main( ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): cls = Conv2dReLU_composite_annotated with R.dataflow(): - gv: R.Tensor( - (1, 64, 56, 56), dtype="float32" - ) = cls.fused_relax_nn_conv2d_relax_nn_relu_dnnl(data, weight1) + gv: R.Tensor((1, 64, 56, 56), dtype="float32") = ( + cls.fused_relax_nn_conv2d_relax_nn_relu_dnnl(data, weight1) + ) R.output(gv) return gv @@ -120,12 +120,12 @@ def main( ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): cls = Conv2dReLUx2Partitioned with R.dataflow(): - lv: R.Tensor( - (1, 64, 56, 56), dtype="float32" - ) = cls.fused_relax_nn_conv2d_relax_nn_relu(data, weight1) - gv: R.Tensor( - (1, 64, 54, 54), dtype="float32" - ) = cls.fused_relax_nn_conv2d_relax_nn_relu1(lv, weight2) + lv: R.Tensor((1, 64, 56, 56), dtype="float32") = ( + cls.fused_relax_nn_conv2d_relax_nn_relu(data, weight1) + ) + gv: R.Tensor((1, 64, 54, 54), dtype="float32") = ( + cls.fused_relax_nn_conv2d_relax_nn_relu1(lv, weight2) + ) R.output(gv) return gv @@ -235,9 +235,9 @@ def main( lv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d( data, weight1 ) - gv: R.Tensor( - (1, 64, 54, 54), dtype="float32" - ) = cls.fused_relax_nn_conv2d_relax_nn_relu(lv, weight2) + gv: R.Tensor((1, 64, 54, 54), dtype="float32") = ( + cls.fused_relax_nn_conv2d_relax_nn_relu(lv, weight2) + ) R.output(gv) return gv @@ -696,10 +696,10 @@ def test_ignore_call_tir(): class Conv2dReLUCallTIR: @T.prim_func def relu( - data: T.Buffer((64, 64, 56, 56), "float32"), - out: T.Buffer((64, 64, 56, 56), "float32"), + data: T.Buffer((1, 64, 56, 56), "float32"), + out: T.Buffer((1, 64, 56, 56), "float32"), ): - for ax0, ax1, ax2, ax3 in T.grid(64, 64, 56, 56): + for ax0, ax1, ax2, ax3 in T.grid(1, 64, 56, 56): with T.block("root"): i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) out[i, j, k, l] = T.max(data[i, j, k, l], 0.0) @@ -714,7 +714,7 @@ def main( relu1 = R.call_tir( Conv2dReLUCallTIR.relu, (conv1,), - R.Tensor((64, 64, 56, 56), "float32"), + R.Tensor((1, 64, 56, 56), "float32"), ) R.output(relu1) @@ -724,11 +724,11 @@ def main( class Conv2dReLUCallTIR_partitioned: @T.prim_func def relu( - data: T.Buffer((64, 64, 56, 56), "float32"), - out: T.Buffer((64, 64, 56, 56), "float32"), + data: T.Buffer((1, 64, 56, 56), "float32"), + out: T.Buffer((1, 64, 56, 56), "float32"), ): # with T.block("root"): - for ax0, ax1, ax2, ax3 in 
T.grid(64, 64, 56, 56): + for ax0, ax1, ax2, ax3 in T.grid(1, 64, 56, 56): with T.block("root"): i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(data[i, j, k, l]) @@ -754,7 +754,7 @@ def fused_relax_nn_conv2d( def main( data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), - ) -> R.Tensor((64, 64, 56, 56), dtype="float32"): + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): cls = Conv2dReLUCallTIR_partitioned with R.dataflow(): lv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d( @@ -763,7 +763,7 @@ def main( relu1 = R.call_tir( cls.relu, (lv,), - out_sinfo=R.Tensor((64, 64, 56, 56), dtype="float32"), + out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"), ) R.output(relu1) return relu1 @@ -903,9 +903,9 @@ def func(inp: R.Tensor((16, 32), "float32")): @tvm.script.ir_module class Expected1: @R.function(private=True) - def fused_relax_split( - inp: R.Tensor((16, 32), dtype="float32") - ) -> R.Tuple(R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16), dtype="float32")): + def fused_relax_split(inp: R.Tensor((16, 32), dtype="float32")) -> R.Tuple( + R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16), dtype="float32") + ): R.func_attr({"Composite": "x.split", "Primitive": 1}) with R.dataflow(): gv: R.Tuple( @@ -932,9 +932,9 @@ def main(inp: R.Tensor((16, 32), dtype="float32")) -> R.Tensor((16, 16), dtype=" @I.ir_module class Expected2: @R.function(private=True) - def fused_relax_split_relax_add( - inp: R.Tensor((16, 32), dtype="float32") - ) -> R.Tensor((16, 16), dtype="float32"): + def fused_relax_split_relax_add(inp: R.Tensor((16, 32), dtype="float32")) -> R.Tensor( + (16, 16), dtype="float32" + ): R.func_attr({"Composite": "x.split", "Primitive": 1}) with R.dataflow(): tup: R.Tuple( @@ -978,9 +978,9 @@ def func1(x: R.Tensor((10, 10), "float32")): @I.ir_module class Expected1: @R.function(private=True) - def fused_relax_clip( - x: R.Tensor((10, 10), dtype="float32") - ) -> R.Tensor((10, 10), dtype="float32"): + def fused_relax_clip(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor( + (10, 10), dtype="float32" + ): R.func_attr({"Composite": "x.clip", "Primitive": 1}) with R.dataflow(): gv: R.Tensor((10, 10), dtype="float32") = R.clip( @@ -1014,9 +1014,9 @@ def func2(x: R.Tensor((10, 10), "float32")): @I.ir_module class Expected2: @R.function(private=True) - def fused_relax_clip( - x: R.Tensor((10, 10), dtype="float32") - ) -> R.Tensor((10, 10), dtype="float32"): + def fused_relax_clip(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor( + (10, 10), dtype="float32" + ): R.func_attr({"Composite": "x.clip", "Primitive": 1}) with R.dataflow(): gv: R.Tensor((10, 10), dtype="float32") = R.clip( @@ -1026,9 +1026,9 @@ def fused_relax_clip( return gv @R.function(private=True) - def fused_relax_clip1( - x: R.Tensor((10, 10), dtype="float32") - ) -> R.Tensor((10, 10), dtype="float32"): + def fused_relax_clip1(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor( + (10, 10), dtype="float32" + ): R.func_attr({"Composite": "x.clip", "Primitive": 1}) with R.dataflow(): gv: R.Tensor((10, 10), dtype="float32") = R.clip( @@ -1038,9 +1038,9 @@ def fused_relax_clip1( return gv @R.function - def main( - x: R.Tensor((10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32")): + def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tuple( + R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32") + ): cls = Expected2 with R.dataflow(): gv: 
R.Tensor((10, 10), dtype="float32") = cls.fused_relax_clip(x) diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py index 278ac825f7a7..ecf5107ed15f 100644 --- a/tests/python/relax/test_transform_lazy_transform_params.py +++ b/tests/python/relax/test_transform_lazy_transform_params.py @@ -43,7 +43,7 @@ def transform_layout_IOHW_to_OIHW( def main_transform_params( params: R.Tuple( R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32") - ) + ), ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32") ): @@ -124,7 +124,7 @@ def transform_layout_IOHW_to_OIHW( def main_transform_params( params: R.Tuple( R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32") - ) + ), ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32") ): @@ -209,7 +209,7 @@ def transform_layout_IOHW_to_OIHW( def main_transform_params( params: R.Tuple( R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32") - ) + ), ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32") ): @@ -298,7 +298,7 @@ def transform_layout_IOHW_to_OIHW( def main_transform_params( params: R.Tuple( R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32") - ) + ), ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32") ): @@ -441,8 +441,8 @@ def main_transform_params( @T.prim_func(private=True) def slice_buffer( Input: T.Buffer((16, 16), "float32"), - slice_index: T.int64, Output: T.Buffer(16, "float32"), + slice_index: T.int64, ): for i in T.grid(16): with T.block("slice_buffer"): @@ -479,8 +479,8 @@ def main_transform_params(slice_shape_expr: R.Shape(["slice_index"])): @T.prim_func(private=True) def slice_buffer( Input: T.Buffer((16, 16), "float32"), - slice_index: T.int64, Output: T.Buffer(16, "float32"), + slice_index: T.int64, ): for i in T.grid(16): with T.block("slice_buffer"): @@ -511,7 +511,7 @@ def main_transform_params( params: R.Tuple( R.Tensor((3, "ic", 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32"), - ) + ), ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor(("ic", 3, 3, 3), dtype="float32") ): @@ -587,9 +587,9 @@ def copy(x: T.Buffer((), "float32"), y: T.Buffer((), "float32")): y[()] = x[()] @R.function - def main_transform_params( - params: R.Tuple(R.Tensor((), dtype="float32")) - ) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32")): + def main_transform_params(params: R.Tuple(R.Tensor((), dtype="float32"))) -> R.Tuple( + R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32") + ): # we expect ToNonDataflow and RemovePurityTracking to be invoked first R.func_attr({"relax.force_pure": True}) cls = Module @@ -637,7 +637,7 @@ def transform_params( params: R.Tuple( R.Tensor((3, "ic", 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32"), - ) + ), ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor(("ic", 3, 3, 3), dtype="float32") ): @@ -691,7 +691,7 @@ def test_duplicate_outputs(): class Before: @R.function def main_transform_params( - params: R.Tuple(R.Tensor([16], dtype="int32"), R.Tensor([16], dtype="int32")) + params: R.Tuple(R.Tensor([16], dtype="int32"), R.Tensor([16], dtype="int32")), ): R.func_attr({"relax.force_pure": True}) param0 = params[0] @@ -966,7 +966,7 @@ def transform_params( class 
Expected: @R.function def transform_params( - fget_param: R.Callable([R.Prim("int64"), R.Object], R.Object) + fget_param: R.Callable([R.Prim("int64"), R.Object], R.Object), ) -> R.Tuple(R.Tensor(ndim=2, dtype="float32"), R.Tensor(ndim=2, dtype="float32")): R.func_attr({"num_input": 1}) m = T.int64() diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py index f7befd3b886a..64ff16cc61f5 100644 --- a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py +++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py @@ -61,9 +61,9 @@ def expand_dims( expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1] = rxplaceholder[i0_1, i2_1, i4_1] @R.function - def main( - x: R.Tensor((8, 3), dtype="float32") - ) -> R.Tensor((2, 1, 4, 1, 3), dtype="float32"): + def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor( + (2, 1, 4, 1, 3), dtype="float32" + ): cls = Module with R.dataflow(): y = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((2, 4, 3), dtype="float32")) @@ -112,9 +112,9 @@ def expand_dims( expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1] = rxplaceholder[i0_1, i2_1, i4_1] @R.function - def main( - x: R.Tensor((8, 3), dtype="float32") - ) -> R.Tensor((2, 1, 4, 1, 3), dtype="float32"): + def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor( + (2, 1, 4, 1, 3), dtype="float32" + ): with R.dataflow(): cls = Expected y: R.Tensor((2, 4, 3), "float32") = R.reshape(x, (2, 4, 3)) @@ -252,11 +252,15 @@ def reshape(var_A: T.handle, var_T_reshape: T.handle): ] @R.function - def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor((2, 4, 3), dtype="float32"): + def main(x: R.Tensor((8, 16, 128), dtype="float16")) -> R.Tensor( + (1, 8, 16, 128), dtype="float16" + ): cls = Module with R.dataflow(): - y = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((2, 4, 3), dtype="float32")) - z = R.add(y, R.const(1, "float32")) + y = R.call_tir( + cls.reshape, (x,), out_sinfo=R.Tensor((1, 8, 16, 128), dtype="float16") + ) + z = R.add(y, R.const(1, "float16")) R.output(z) return z @@ -290,10 +294,14 @@ def reshape(var_A: T.handle, var_T_reshape: T.handle): ] @R.function - def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor((2, 4, 3), dtype="float32"): + def main(x: R.Tensor((8, 16, 128), dtype="float16")) -> R.Tensor( + (1, 8, 16, 128), dtype="float16" + ): with R.dataflow(): - y: R.Tensor((2, 4, 3), dtype="float32") = R.reshape(x, R.shape([2, 4, 3])) - z: R.Tensor((2, 4, 3), dtype="float32") = R.add(y, R.const(1, "float32")) + y: R.Tensor((1, 8, 16, 128), dtype="float16") = R.reshape( + x, R.shape([1, 8, 16, 128]) + ) + z: R.Tensor((1, 8, 16, 128), dtype="float16") = R.add(y, R.const(1, "float16")) R.output(z) return z @@ -383,7 +391,7 @@ def main( R.Tensor((2, 4096, 320), dtype="float16"), R.Tensor((2, 4096, 320), dtype="float16"), R.Tensor((2, 4096, 320), dtype="float16"), - ) + ), ) -> R.Tensor((2, 4096, 8, 40), dtype="float16"): cls = Module with R.dataflow(): @@ -444,7 +452,7 @@ def main( R.Tensor((2, 4096, 320), dtype="float16"), R.Tensor((2, 4096, 320), dtype="float16"), R.Tensor((2, 4096, 320), dtype="float16"), - ) + ), ) -> R.Tensor((2, 4096, 8, 40), dtype="float16"): with R.dataflow(): lv: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[0] @@ -735,7 +743,6 @@ def add( z_handle: T.handle, N: T.int64, ): - y1 = T.match_buffer(y1_handle, [N * 4, T.int64(4)], "float32") y2 = T.match_buffer(y2_handle, [N * 4, T.int64(4)], "float32") z = T.match_buffer(z_handle, [N * 4, T.int64(4)], "float32") diff --git 
a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 4f41b662caf2..96d162dfa1dc 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -77,7 +77,7 @@ def test_mismatch_cast_dims_and_ndim(): @R.function def f( - x: R.Tensor((2, 3), "float32", ndim=3) + x: R.Tensor((2, 3), "float32", ndim=3), ): # error: ndim and the shape dims are mismatch return x @@ -961,11 +961,11 @@ def test_call_tir_with_tir_var(): class Module: @R.function def main( - dumb_param: R.Tensor(("n",), "float32"), x: R.Tensor(("n * 2", "float32")) + dumb_param: R.Tensor(("n",), "float32"), x: R.Tensor(("n * 2",), "float32") ) -> R.Tensor(("n * 2",), "float32"): n = T.int64() cls = Module - y = R.call_tir(cls.copy, (x,), R.Tensor(((n * 2,)), dtype="float32"), tir_vars=(n,)) + y = R.call_tir(cls.copy, x, R.Tensor((n * 2,), dtype="float32"), tir_vars=(n,)) return y @T.prim_func @@ -1028,9 +1028,7 @@ def copy( out1[ax0, ax1] = B[ax0, ax1] @R.function - def main( - x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32") - ) -> R.Tuple( + def main(x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32")) -> R.Tuple( R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32") ): res = R.call_tir_inplace( @@ -1046,13 +1044,13 @@ def main( def test_local_function(): @R.function - def main( - x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") - ) -> R.Tensor((2, 3), "float32"): + def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")) -> R.Tensor( + (2, 3), "float32" + ): @R.function - def outer_func( - c1: R.Tensor((2, 3), "float32") - ) -> R.Callable((R.Tensor(None, "float32", ndim=2),), R.Tensor(None, "float32", ndim=2)): + def outer_func(c1: R.Tensor((2, 3), "float32")) -> R.Callable( + (R.Tensor(None, "float32", ndim=2),), R.Tensor(None, "float32", ndim=2) + ): @R.function def inner_func(x1: R.Tensor((2, 3), "float32")): s: R.Tensor((2, 3), "float32") = R.add(x1, c1) @@ -1487,9 +1485,9 @@ def test_erase_to_well_defined_infers_from_prim_value(): class Module: # The subroutine's symbolic variables are only in-scope for the subroutine. 
@R.function - def subroutine( - x: R.Tensor, _m: R.Prim(value="m"), _n: R.Prim(value="n") - ) -> R.Tensor(["m", "n"]): + def subroutine(x: R.Tensor, _m: R.Prim(value="m"), _n: R.Prim(value="n")) -> R.Tensor( + ["m", "n"] + ): q = x m, n = T.int64(), T.int64() z = R.match_cast(q, R.Tensor((m, n))) @@ -1547,9 +1545,9 @@ def test_symbolic_vars_in_tensor_shape_with_definition_first(): """Second param may use symbolic variable defined in first param""" @R.function - def bar( - x: R.Tensor(("m",), "float32"), y: R.Tensor(("T.max(m, 20)",), "float32") - ) -> R.Tensor(("T.max(m, 20) + 1",), "float32"): + def bar(x: R.Tensor(("m",), "float32"), y: R.Tensor(("T.max(m, 20)",), "float32")) -> R.Tensor( + ("T.max(m, 20) + 1",), "float32" + ): m = T.int64() z = R.call_dps_packed("test_intrin", (x, y), R.Tensor((T.max(m, 20) + 1,), dtype="float32")) return z @@ -2014,9 +2012,9 @@ def test_function_with_void_return_type_in_if_else(): @I.ir_module class Unsugared: @R.function(pure=False) - def conditional( - x: R.Tensor((), "int32"), condition: R.Tensor((), "bool") - ) -> R.Tensor((), "int32"): + def conditional(x: R.Tensor((), "int32"), condition: R.Tensor((), "bool")) -> R.Tensor( + (), "int32" + ): if condition: y = R.print(x, format="True condition: {}") else: @@ -2026,9 +2024,9 @@ def conditional( @I.ir_module class Sugared: @R.function(pure=False) - def conditional( - x: R.Tensor((), "int32"), condition: R.Tensor((), "bool") - ) -> R.Tensor((), "int32"): + def conditional(x: R.Tensor((), "int32"), condition: R.Tensor((), "bool")) -> R.Tensor( + (), "int32" + ): if condition: R.print(x, format="True condition: {}") else: @@ -2135,7 +2133,9 @@ def func(z: R.Tensor((4, 4), "float32")): @R.function(private=True) def expect(z: R.Tensor((4, 4), dtype="float32")) -> R.Shape([4, 4]): alloc: R.Tensor((4, 4), dtype="float32") = R.builtin.alloc_tensor( - R.shape([4, 4]), R.dtype("float32"), R.prim_value(2) # Make sure prim_value is 2 + R.shape([4, 4]), + R.dtype("float32"), + R.prim_value(2), # Make sure prim_value is 2 ) shape: R.Shape([4, 4]) = R.shape_of(alloc) shape_1: R.Shape([4, 4]) = shape @@ -2167,7 +2167,9 @@ def func(z: R.Tensor((4, 4), "float32")): @R.function(private=True) def expect(z: R.Tensor((4, 4), dtype="float32")) -> R.Shape([4, 4]): alloc: R.Tensor((4, 4), dtype="float32") = R.builtin.alloc_tensor( - R.shape([4, 4]), R.dtype("float32"), R.prim_value(1) # Make sure prim_value is 1 + R.shape([4, 4]), + R.dtype("float32"), + R.prim_value(1), # Make sure prim_value is 1 ) shape: R.Shape([4, 4]) = R.shape_of(alloc) shape_1: R.Shape([4, 4]) = shape @@ -2336,7 +2338,6 @@ def explicit_sinfo( B: R.Tensor(["N"], "float32"), cond: R.Prim("bool"), ) -> R.Tensor(["N"], "float32"): - N = T.int64() if cond: diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index 30fd06d4f14d..0c202df6bfe6 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -298,9 +298,9 @@ def inplace_add(A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32")): A[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1] @R.function - def main( - x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32") - ) -> R.Tensor((2, 3), "int32"): + def main(x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32")) -> R.Tensor( + (2, 3), "int32" + ): res = R.call_tir_inplace( TestCallTIRInplaceE2ERW.inplace_add, (x, y), [0], R.Tensor((2, 3), "int32") ) @@ -955,16 +955,14 @@ def test_vm_mul(x: T.handle, y: T.handle, z: T.handle): # test returning a tuple @R.function - def test_vm_tuple( 
- x: R.Tensor((), "int32") - ) -> R.Tuple(R.Tensor((), "int32"), R.Tensor((), "int32")): + def test_vm_tuple(x: R.Tensor((), "int32")) -> R.Tuple( + R.Tensor((), "int32"), R.Tensor((), "int32") + ): return (x, x) # nested tuple too @R.function - def test_vm_nested_tuple( - x: R.Tensor((), "int32") - ) -> R.Tuple( + def test_vm_nested_tuple(x: R.Tensor((), "int32")) -> R.Tuple( R.Tuple( R.Tensor((), "int32"), R.Tuple( @@ -988,8 +986,10 @@ class ModA: I.module_attrs({"system_lib_prefix": "libA_"}) @T.prim_func - def tir_init(x: T.Buffer((2), "float32")) -> None: - for i in range(2): + def tir_init(x_handle: T.handle): + N = T.int64() + x = T.match_buffer(x_handle, [N], "float32") + for i in range(N): x[i] = T.float32(0) @R.function @@ -1003,8 +1003,10 @@ class ModB: I.module_attrs({"system_lib_prefix": "libB_"}) @T.prim_func - def tir_init(x: T.Buffer((2), "float32")) -> None: - for i in range(2): + def tir_init(x_handle: T.handle): + N = T.int64() + x = T.match_buffer(x_handle, [N], "float32") + for i in range(N): x[i] = T.float32(1) @R.function From e44e11785f841aaac2a0864e1f80303bf915fedf Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 22 Aug 2024 10:52:54 -0500 Subject: [PATCH 4/6] lint fixes --- tests/python/relax/test_dataflow_inplace.py | 6 +- tests/python/relax/test_frontend_dynamo.py | 34 ++++++----- .../test_transform_dead_code_elimination.py | 30 +++++----- tests/python/relax/test_transform_fuse_ops.py | 6 +- .../test_transform_fuse_ops_by_pattern.py | 60 +++++++++---------- .../test_transform_lazy_transform_params.py | 6 +- ...test_transform_rewrite_dataflow_reshape.py | 24 ++++---- tests/python/relax/test_tvmscript_parser.py | 40 +++++++------ tests/python/relax/test_vm_build.py | 16 ++--- 9 files changed, 114 insertions(+), 108 deletions(-) diff --git a/tests/python/relax/test_dataflow_inplace.py b/tests/python/relax/test_dataflow_inplace.py index a127b0fa263f..cd6e285de499 100644 --- a/tests/python/relax/test_dataflow_inplace.py +++ b/tests/python/relax/test_dataflow_inplace.py @@ -323,9 +323,9 @@ def test_inplace_simple_case(): @I.ir_module class InplaceBasic: @R.function - def main(x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32")) -> R.Tensor( - (2, 3), "int32" - ): + def main( + x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32") + ) -> R.Tensor((2, 3), "int32"): with R.dataflow(): z = R.add(x, y) # cannot be done inplace: x and y are live later p = R.add(z, z) # can be done inplace: z is not used later diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py index ed9f628aea26..21e1d82d28b5 100644 --- a/tests/python/relax/test_frontend_dynamo.py +++ b/tests/python/relax/test_frontend_dynamo.py @@ -314,9 +314,9 @@ def forward(self, input): @I.ir_module class Expected1: @R.function - def main(inp_0: R.Tensor((256, 256), dtype="float32")) -> R.Tensor( - (10, 10), dtype="float32" - ): + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tensor((10, 10), dtype="float32"): with R.dataflow(): lv: R.Tensor((10, 10), dtype="float32") = R.full( R.shape([10, 10]), R.const(1, "float32"), dtype="float32" @@ -345,9 +345,9 @@ def forward(self, input): @I.ir_module class Expected1: @R.function - def main(inp_0: R.Tensor((256, 256), dtype="float32")) -> R.Tensor( - (10, 10), dtype="float32" - ): + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tensor((10, 10), dtype="float32"): with R.dataflow(): lv: R.Tensor((10, 10), dtype="float32") = R.full( R.shape([10, 10]), R.const(1, "float32"), 
dtype="float32" @@ -380,9 +380,9 @@ def forward(self, input): @I.ir_module class ExpectedGeLU: @R.function - def main(inp_0: R.Tensor((128, 256), dtype="float32")) -> R.Tensor( - (128, 256), dtype="float32" - ): + def main( + inp_0: R.Tensor((128, 256), dtype="float32") + ) -> R.Tensor((128, 256), dtype="float32"): with R.dataflow(): lv: R.Tensor((128, 256), dtype="float32") = R.nn.gelu(inp_0) gv: R.Tensor((128, 256), dtype="float32") = lv @@ -392,9 +392,9 @@ def main(inp_0: R.Tensor((128, 256), dtype="float32")) -> R.Tensor( @I.ir_module class ExpectedGeLUTanh: @R.function - def main(inp_0: R.Tensor((128, 256), dtype="float32")) -> R.Tensor( - (128, 256), dtype="float32" - ): + def main( + inp_0: R.Tensor((128, 256), dtype="float32") + ) -> R.Tensor((128, 256), dtype="float32"): with R.dataflow(): lv: R.Tensor((128, 256), dtype="float32") = R.nn.gelu_tanh(inp_0) gv: R.Tensor((128, 256), dtype="float32") = lv @@ -489,9 +489,9 @@ def main( @I.ir_module class Expected2: @R.function - def main(inp_0: R.Tensor((1, 77, 1280), dtype="float32")) -> R.Tensor( - (1, 77, 1280), dtype="float32" - ): + def main( + inp_0: R.Tensor((1, 77, 1280), dtype="float32") + ) -> R.Tensor((1, 77, 1280), dtype="float32"): with R.dataflow(): lv: R.Tensor((1,), dtype="int64") = R.arange( R.prim_value(0), R.prim_value(1), R.prim_value(1), dtype="int64" @@ -514,7 +514,9 @@ def main(inp_0: R.Tensor((1, 77, 1280), dtype="float32")) -> R.Tensor( class Select2(Module): def forward(self, input1): - result = input1[torch.arange(1),] + result = input1[ + torch.arange(1), + ] return result verify_dynamo_model( diff --git a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py index 5e3d50b9e1f5..fb915262cc81 100644 --- a/tests/python/relax/test_transform_dead_code_elimination.py +++ b/tests/python/relax/test_transform_dead_code_elimination.py @@ -176,9 +176,9 @@ def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32 return gv0 @R.function - def main(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")) -> R.Tensor( - (16, 16), "float32" - ): + def main( + x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") + ) -> R.Tensor((16, 16), "float32"): gv0 = R.call_tir(InputModule.tir_add, (x, w), R.Tensor((16, 16), dtype="float32")) return gv0 @@ -213,9 +213,9 @@ def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32 return gv0 @R.function - def foo(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")) -> R.Tensor( - (16, 16), "float32" - ): + def foo( + x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") + ) -> R.Tensor((16, 16), "float32"): gv0 = R.call_tir(InputModule.tir_add, (x, w), R.Tensor((16, 16), dtype="float32")) return gv0 @@ -254,9 +254,9 @@ def unused_func(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32 return gv0 @R.function - def foo(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")) -> R.Tensor( - (16, 16), "float32" - ): + def foo( + x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") + ) -> R.Tensor((16, 16), "float32"): gv0 = R.call_tir(InputModule.tir_add, (x, w), R.Tensor((16, 16), dtype="float32")) return gv0 @@ -339,9 +339,9 @@ def relax_add(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") return gv0 @R.function - def main(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")) -> R.Tensor( - (16, 16), "float32" - ): + def main( + x: R.Tensor((16, 16), 
"float32"), w: R.Tensor((16, 16), "float32") + ) -> R.Tensor((16, 16), "float32"): gv0 = InputModule.relax_add(x, w) return gv0 @@ -375,9 +375,9 @@ def unused_func2(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float3 return gv0 @R.function - def main(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")) -> R.Tensor( - (16, 16), "float32" - ): + def main( + x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32") + ) -> R.Tensor((16, 16), "float32"): gv0 = R.add(x, w) return gv0 diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py index f751e1b1a7df..9ad66bec012a 100644 --- a/tests/python/relax/test_transform_fuse_ops.py +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -1633,9 +1633,9 @@ def main( ) -> R.Tensor((10, 20), dtype="float32"): cls = Expected with R.dataflow(): - gv1: R.Tensor((10, 20), dtype="float32") = ( - cls.fused_add_exp_inplace_squeeze_inplace(x, p0) - ) + gv1: R.Tensor( + (10, 20), dtype="float32" + ) = cls.fused_add_exp_inplace_squeeze_inplace(x, p0) R.output(gv1) return gv1 diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py index 0798ca462bf9..a07875fcdae6 100644 --- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py +++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py @@ -56,9 +56,9 @@ def main( ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): cls = Conv2dReLU_composite_annotated with R.dataflow(): - gv: R.Tensor((1, 64, 56, 56), dtype="float32") = ( - cls.fused_relax_nn_conv2d_relax_nn_relu_dnnl(data, weight1) - ) + gv: R.Tensor( + (1, 64, 56, 56), dtype="float32" + ) = cls.fused_relax_nn_conv2d_relax_nn_relu_dnnl(data, weight1) R.output(gv) return gv @@ -120,12 +120,12 @@ def main( ) -> R.Tensor((1, 64, 54, 54), dtype="float32"): cls = Conv2dReLUx2Partitioned with R.dataflow(): - lv: R.Tensor((1, 64, 56, 56), dtype="float32") = ( - cls.fused_relax_nn_conv2d_relax_nn_relu(data, weight1) - ) - gv: R.Tensor((1, 64, 54, 54), dtype="float32") = ( - cls.fused_relax_nn_conv2d_relax_nn_relu1(lv, weight2) - ) + lv: R.Tensor( + (1, 64, 56, 56), dtype="float32" + ) = cls.fused_relax_nn_conv2d_relax_nn_relu(data, weight1) + gv: R.Tensor( + (1, 64, 54, 54), dtype="float32" + ) = cls.fused_relax_nn_conv2d_relax_nn_relu1(lv, weight2) R.output(gv) return gv @@ -235,9 +235,9 @@ def main( lv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d( data, weight1 ) - gv: R.Tensor((1, 64, 54, 54), dtype="float32") = ( - cls.fused_relax_nn_conv2d_relax_nn_relu(lv, weight2) - ) + gv: R.Tensor( + (1, 64, 54, 54), dtype="float32" + ) = cls.fused_relax_nn_conv2d_relax_nn_relu(lv, weight2) R.output(gv) return gv @@ -903,9 +903,9 @@ def func(inp: R.Tensor((16, 32), "float32")): @tvm.script.ir_module class Expected1: @R.function(private=True) - def fused_relax_split(inp: R.Tensor((16, 32), dtype="float32")) -> R.Tuple( - R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16), dtype="float32") - ): + def fused_relax_split( + inp: R.Tensor((16, 32), dtype="float32") + ) -> R.Tuple(R.Tensor((16, 16), dtype="float32"), R.Tensor((16, 16), dtype="float32")): R.func_attr({"Composite": "x.split", "Primitive": 1}) with R.dataflow(): gv: R.Tuple( @@ -932,9 +932,9 @@ def main(inp: R.Tensor((16, 32), dtype="float32")) -> R.Tensor((16, 16), dtype=" @I.ir_module class Expected2: @R.function(private=True) - def fused_relax_split_relax_add(inp: R.Tensor((16, 32), dtype="float32")) -> R.Tensor( - (16, 16), 
dtype="float32" - ): + def fused_relax_split_relax_add( + inp: R.Tensor((16, 32), dtype="float32") + ) -> R.Tensor((16, 16), dtype="float32"): R.func_attr({"Composite": "x.split", "Primitive": 1}) with R.dataflow(): tup: R.Tuple( @@ -978,9 +978,9 @@ def func1(x: R.Tensor((10, 10), "float32")): @I.ir_module class Expected1: @R.function(private=True) - def fused_relax_clip(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor( - (10, 10), dtype="float32" - ): + def fused_relax_clip( + x: R.Tensor((10, 10), dtype="float32") + ) -> R.Tensor((10, 10), dtype="float32"): R.func_attr({"Composite": "x.clip", "Primitive": 1}) with R.dataflow(): gv: R.Tensor((10, 10), dtype="float32") = R.clip( @@ -1014,9 +1014,9 @@ def func2(x: R.Tensor((10, 10), "float32")): @I.ir_module class Expected2: @R.function(private=True) - def fused_relax_clip(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor( - (10, 10), dtype="float32" - ): + def fused_relax_clip( + x: R.Tensor((10, 10), dtype="float32") + ) -> R.Tensor((10, 10), dtype="float32"): R.func_attr({"Composite": "x.clip", "Primitive": 1}) with R.dataflow(): gv: R.Tensor((10, 10), dtype="float32") = R.clip( @@ -1026,9 +1026,9 @@ def fused_relax_clip(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor( return gv @R.function(private=True) - def fused_relax_clip1(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor( - (10, 10), dtype="float32" - ): + def fused_relax_clip1( + x: R.Tensor((10, 10), dtype="float32") + ) -> R.Tensor((10, 10), dtype="float32"): R.func_attr({"Composite": "x.clip", "Primitive": 1}) with R.dataflow(): gv: R.Tensor((10, 10), dtype="float32") = R.clip( @@ -1038,9 +1038,9 @@ def fused_relax_clip1(x: R.Tensor((10, 10), dtype="float32")) -> R.Tensor( return gv @R.function - def main(x: R.Tensor((10, 10), dtype="float32")) -> R.Tuple( - R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32") - ): + def main( + x: R.Tensor((10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32")): cls = Expected2 with R.dataflow(): gv: R.Tensor((10, 10), dtype="float32") = cls.fused_relax_clip(x) diff --git a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py index ecf5107ed15f..87a5698f1bf8 100644 --- a/tests/python/relax/test_transform_lazy_transform_params.py +++ b/tests/python/relax/test_transform_lazy_transform_params.py @@ -587,9 +587,9 @@ def copy(x: T.Buffer((), "float32"), y: T.Buffer((), "float32")): y[()] = x[()] @R.function - def main_transform_params(params: R.Tuple(R.Tensor((), dtype="float32"))) -> R.Tuple( - R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32") - ): + def main_transform_params( + params: R.Tuple(R.Tensor((), dtype="float32")) + ) -> R.Tuple(R.Tensor((), dtype="float32"), R.Tensor((), dtype="float32")): # we expect ToNonDataflow and RemovePurityTracking to be invoked first R.func_attr({"relax.force_pure": True}) cls = Module diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py index 64ff16cc61f5..5a7d76d8fe41 100644 --- a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py +++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py @@ -61,9 +61,9 @@ def expand_dims( expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1] = rxplaceholder[i0_1, i2_1, i4_1] @R.function - def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor( - (2, 1, 4, 1, 3), dtype="float32" - ): + def main( + x: R.Tensor((8, 
3), dtype="float32") + ) -> R.Tensor((2, 1, 4, 1, 3), dtype="float32"): cls = Module with R.dataflow(): y = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((2, 4, 3), dtype="float32")) @@ -112,9 +112,9 @@ def expand_dims( expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1] = rxplaceholder[i0_1, i2_1, i4_1] @R.function - def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor( - (2, 1, 4, 1, 3), dtype="float32" - ): + def main( + x: R.Tensor((8, 3), dtype="float32") + ) -> R.Tensor((2, 1, 4, 1, 3), dtype="float32"): with R.dataflow(): cls = Expected y: R.Tensor((2, 4, 3), "float32") = R.reshape(x, (2, 4, 3)) @@ -252,9 +252,9 @@ def reshape(var_A: T.handle, var_T_reshape: T.handle): ] @R.function - def main(x: R.Tensor((8, 16, 128), dtype="float16")) -> R.Tensor( - (1, 8, 16, 128), dtype="float16" - ): + def main( + x: R.Tensor((8, 16, 128), dtype="float16") + ) -> R.Tensor((1, 8, 16, 128), dtype="float16"): cls = Module with R.dataflow(): y = R.call_tir( @@ -294,9 +294,9 @@ def reshape(var_A: T.handle, var_T_reshape: T.handle): ] @R.function - def main(x: R.Tensor((8, 16, 128), dtype="float16")) -> R.Tensor( - (1, 8, 16, 128), dtype="float16" - ): + def main( + x: R.Tensor((8, 16, 128), dtype="float16") + ) -> R.Tensor((1, 8, 16, 128), dtype="float16"): with R.dataflow(): y: R.Tensor((1, 8, 16, 128), dtype="float16") = R.reshape( x, R.shape([1, 8, 16, 128]) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 96d162dfa1dc..deb751738946 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -1028,7 +1028,9 @@ def copy( out1[ax0, ax1] = B[ax0, ax1] @R.function - def main(x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32")) -> R.Tuple( + def main( + x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32") + ) -> R.Tuple( R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32"), R.Tensor((2, 3), "int32") ): res = R.call_tir_inplace( @@ -1044,13 +1046,13 @@ def main(x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32")) -> R.Tuple( def test_local_function(): @R.function - def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")) -> R.Tensor( - (2, 3), "float32" - ): + def main( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): @R.function - def outer_func(c1: R.Tensor((2, 3), "float32")) -> R.Callable( - (R.Tensor(None, "float32", ndim=2),), R.Tensor(None, "float32", ndim=2) - ): + def outer_func( + c1: R.Tensor((2, 3), "float32") + ) -> R.Callable((R.Tensor(None, "float32", ndim=2),), R.Tensor(None, "float32", ndim=2)): @R.function def inner_func(x1: R.Tensor((2, 3), "float32")): s: R.Tensor((2, 3), "float32") = R.add(x1, c1) @@ -1485,9 +1487,9 @@ def test_erase_to_well_defined_infers_from_prim_value(): class Module: # The subroutine's symbolic variables are only in-scope for the subroutine. 
@R.function - def subroutine(x: R.Tensor, _m: R.Prim(value="m"), _n: R.Prim(value="n")) -> R.Tensor( - ["m", "n"] - ): + def subroutine( + x: R.Tensor, _m: R.Prim(value="m"), _n: R.Prim(value="n") + ) -> R.Tensor(["m", "n"]): q = x m, n = T.int64(), T.int64() z = R.match_cast(q, R.Tensor((m, n))) @@ -1545,9 +1547,9 @@ def test_symbolic_vars_in_tensor_shape_with_definition_first(): """Second param may use symbolic variable defined in first param""" @R.function - def bar(x: R.Tensor(("m",), "float32"), y: R.Tensor(("T.max(m, 20)",), "float32")) -> R.Tensor( - ("T.max(m, 20) + 1",), "float32" - ): + def bar( + x: R.Tensor(("m",), "float32"), y: R.Tensor(("T.max(m, 20)",), "float32") + ) -> R.Tensor(("T.max(m, 20) + 1",), "float32"): m = T.int64() z = R.call_dps_packed("test_intrin", (x, y), R.Tensor((T.max(m, 20) + 1,), dtype="float32")) return z @@ -2012,9 +2014,9 @@ def test_function_with_void_return_type_in_if_else(): @I.ir_module class Unsugared: @R.function(pure=False) - def conditional(x: R.Tensor((), "int32"), condition: R.Tensor((), "bool")) -> R.Tensor( - (), "int32" - ): + def conditional( + x: R.Tensor((), "int32"), condition: R.Tensor((), "bool") + ) -> R.Tensor((), "int32"): if condition: y = R.print(x, format="True condition: {}") else: @@ -2024,9 +2026,9 @@ def conditional(x: R.Tensor((), "int32"), condition: R.Tensor((), "bool")) -> R. @I.ir_module class Sugared: @R.function(pure=False) - def conditional(x: R.Tensor((), "int32"), condition: R.Tensor((), "bool")) -> R.Tensor( - (), "int32" - ): + def conditional( + x: R.Tensor((), "int32"), condition: R.Tensor((), "bool") + ) -> R.Tensor((), "int32"): if condition: R.print(x, format="True condition: {}") else: diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index 0c202df6bfe6..ecf33aa9da1e 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -298,9 +298,9 @@ def inplace_add(A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32")): A[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1] @R.function - def main(x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32")) -> R.Tensor( - (2, 3), "int32" - ): + def main( + x: R.Tensor((2, 3), "int32"), y: R.Tensor((2, 3), "int32") + ) -> R.Tensor((2, 3), "int32"): res = R.call_tir_inplace( TestCallTIRInplaceE2ERW.inplace_add, (x, y), [0], R.Tensor((2, 3), "int32") ) @@ -955,14 +955,16 @@ def test_vm_mul(x: T.handle, y: T.handle, z: T.handle): # test returning a tuple @R.function - def test_vm_tuple(x: R.Tensor((), "int32")) -> R.Tuple( - R.Tensor((), "int32"), R.Tensor((), "int32") - ): + def test_vm_tuple( + x: R.Tensor((), "int32") + ) -> R.Tuple(R.Tensor((), "int32"), R.Tensor((), "int32")): return (x, x) # nested tuple too @R.function - def test_vm_nested_tuple(x: R.Tensor((), "int32")) -> R.Tuple( + def test_vm_nested_tuple( + x: R.Tensor((), "int32") + ) -> R.Tuple( R.Tuple( R.Tensor((), "int32"), R.Tuple( From 071eea124050c2d387d9a18ae2acb9b45a8c4b34 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 29 Aug 2024 08:56:04 -0500 Subject: [PATCH 5/6] Restrict the R.call_tir validation to well-formed checker Initial implementation performed the validation as part of `FNormalize`, to maximize coverage of this check. This increased end-to-end compilation time by ~10%, and so the check was requested to be restricted to the well-formed checker. Expensive operator-specific validation is now performed in the new `FValidate` attribute. 
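For illustration only, a minimal sketch of how an operator opts in to this check, following the same registration pattern the diff below uses for `FNormalize`; the operator name `relax.my_op` and the helper `ValidateMyOp` are hypothetical, and the sketch assumes the `FValidate` callback takes the `relax::Call` by value and returns nothing, as `ValidateCallTIR` does in this patch:

    // Hypothetical example, not part of this patch.
    // Runs only from the well-formed checker, so comparatively expensive
    // checks do not slow down BlockBuilder normalization.
    void ValidateMyOp(Call call) {
      CHECK_EQ(call->sinfo_args.size(), 1)
          << "relax.my_op expects exactly one entry in sinfo_args";
    }

    RELAY_REGISTER_OP("relax.my_op")
        .set_attr<FValidate>("FValidate", ValidateMyOp);

A CHECK failure inside the callback is caught by the well-formed checker's try/catch and reported as a Malformed diagnostic for the offending call.
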
--- include/tvm/relax/op_attr_types.h | 27 ++++++++++++ src/relax/analysis/well_formed.cc | 11 +++++ src/relax/op/op.cc | 68 +++++++++++++++++++------------ 3 files changed, 81 insertions(+), 25 deletions(-) diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h index 291bee597c03..0ddc2baefbef 100644 --- a/include/tvm/relax/op_attr_types.h +++ b/include/tvm/relax/op_attr_types.h @@ -56,6 +56,14 @@ using FCallPacked = String; * expressed in multiple syntactically valid and semantically * equivalent forms, to normalize to a single representation. * + * Note: `FNormalize` is applied for each expression as part of the + * `relax::BlockBuilder`. While operator-specific validation may + * be performed within the `FNormalize` implementation, ensuring + * that errors are caught as early as possible, this should only be + * used when validation is fast to apply. If the validation logic + * may be slow, it should instead be implemented in `FValidate`, + * which is only run as part of the well-formed checker. + * * \param bb The BlockBuilder context. * * \param call The call to be normalized. It is provided by-value, to @@ -63,6 +71,25 @@ using FNormalize = runtime::TypedPackedFunc; +/*! + * \brief The function type of a validation function. + * + * A validation function is used to define constraints that should be + * verified for an operator as part of the well-formed checker. + * + * Note: `FValidate` is only applied as part of the well-formed + * checker. While this minimizes overhead while compiling Relax, + * this delay between generating an ill-formed `relax::Call` and + * identifying the ill-formed call may complicate debugging. If + * the validation logic is very fast to check, and doing so would + * not introduce a significant overhead, consider validating as part + * of `FNormalize`, which is applied by the block builder for each + * `relax::Call`. + * + * \param call The call to be validated. + */ +using FValidate = runtime::TypedPackedFunc; + /*! \brief The function type of a legalization function. 
* * A legalization function is used to replace a `relax::Call` with diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 626fadda273d..235059ece2aa 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -352,6 +352,16 @@ class WellFormedChecker : public relax::ExprVisitor, << after_normalize); } } + + if (auto func_validate = op_map_validate_.get(call->op, nullptr); func_validate != nullptr) { + try { + func_validate(GetRef(call)); + } catch (std::exception& err) { + Malformed(Diagnostic::Error(call) << "Operator-specific validation (FValidate) for " + << call->op << " identified error: \n" + << err.what()); + } + } } void VisitExpr_(const IfNode* op) final { @@ -574,6 +584,7 @@ class WellFormedChecker : public relax::ExprVisitor, std::unordered_map symbolic_var_func_map_; tvm::OpAttrMap op_map_normalize_ = Op::GetAttrMap("FNormalize"); + tvm::OpAttrMap op_map_validate_ = Op::GetAttrMap("FValidate"); }; bool WellFormed(Variant obj, bool check_struct_info) { diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 2e42e1a06167..3e0f0eba313a 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -463,41 +463,18 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { << ", which is neither an in-line tuple, " << "nor a variable binding that may be normalized to an in-line tuple."; - auto packed_int_sinfo = [&]() -> Optional { - if (call->args.size() <= 2) { - return NullOpt; - } - + if (call->args.size() > 2) { Expr packed_ints = call->args[2]; CHECK(packed_ints->struct_info_.as()) << "Operation " << call->op << " expects the optional third argument, " << "if present, to be a ShapeTuple. " << "However, the third argument " << packed_ints << " has struct info " << packed_ints->struct_info_; - return GetStructInfo(packed_ints); - }(); - - auto opt_inplace_indices = [&]() -> Optional> { - if (const auto* attrs = call->attrs.as()) { - return attrs->inplace_indices; - } else { - return NullOpt; - } - }(); + } CHECK_EQ(call->sinfo_args.size(), 1) << "R.call_tir should have exactly one `sinfo_args` parameter, " << "which defines the output of the PrimFunc."; - StructInfo explicit_sinfo = call->sinfo_args[0]; - auto inferred_sinfo = InferCallTIROutputStructInfoFromArguments( - GetStructInfo(callee), GetStructInfo(arg_tuple), packed_int_sinfo, opt_inplace_indices); - if (inferred_sinfo.defined()) { - CHECK(IsBaseOf(inferred_sinfo.value(), explicit_sinfo)) - << "TypeError: " - << "The `out_sinfo` argument for R.call_tir must be compatible with the PrimFunc. " - << "However, the PrimFunc's signature implies that the output should be " << inferred_sinfo - << ", but the `out_sinfo` argument was " << explicit_sinfo; - } auto unwrap_binding = [&ctx](Expr expr) -> Optional { if (auto var = expr.as()) { @@ -547,6 +524,44 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { return std::move(call); } +void ValidateCallTIR(Call call) { + // This function is used for validation of `relax.call_tir`, + // along with the variants `relax.call_tir_with_grad` and + // `relax.call_tir_inplace`. 
Therefore, all error messages should + // be written in terms of `call->op`, and should not explicitly + // reference the `relax.call_tir` operator.` + + auto callee = call->args[0]; + Expr arg_tuple = call->args[1]; + + auto packed_int_sinfo = [&]() -> Optional { + if (call->args.size() <= 2) { + return NullOpt; + } else { + return GetStructInfo(call->args[2]); + } + }(); + + auto opt_inplace_indices = [&]() -> Optional> { + if (const auto* attrs = call->attrs.as()) { + return attrs->inplace_indices; + } else { + return NullOpt; + } + }(); + + StructInfo explicit_sinfo = call->sinfo_args[0]; + auto inferred_sinfo = InferCallTIROutputStructInfoFromArguments( + GetStructInfo(callee), GetStructInfo(arg_tuple), packed_int_sinfo, opt_inplace_indices); + if (inferred_sinfo.defined()) { + CHECK(IsBaseOf(inferred_sinfo.value(), explicit_sinfo)) + << "TypeError: " + << "The `out_sinfo` argument for R.call_tir must be compatible with the PrimFunc. " + << "However, the PrimFunc's signature implies that the output should be " << inferred_sinfo + << ", but the `out_sinfo` argument was " << explicit_sinfo; + } +} + RELAY_REGISTER_OP("relax.call_tir") .set_num_inputs(3) .add_argument("func", "Expr", "The destination-passing-style function.") @@ -556,6 +571,7 @@ RELAY_REGISTER_OP("relax.call_tir") "args if unused") .set_attr("FInferStructInfo", InferStructInfoCallTIR) .set_attr("FNormalize", NormalizeCallTIR) + .set_attr("FValidate", ValidateCallTIR) .set_attr("FPurity", Bool(true)); Expr MakeCallTIR(Expr func, Tuple args, Array out_sinfo_list, @@ -601,6 +617,7 @@ RELAY_REGISTER_OP("relax.call_tir_with_grad") "args if unused") .set_attr("FInferStructInfo", InferStructInfoCallTIR) .set_attr("FNormalize", NormalizeCallTIR) + .set_attr("FValidate", ValidateCallTIR) .set_attr("FPurity", Bool(true)); Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array out_sinfo_list, @@ -741,6 +758,7 @@ RELAY_REGISTER_OP("relax.call_tir_inplace") "args if unused") .set_attr("FInferStructInfo", InferStructInfoCallTIR) .set_attr("FNormalize", NormalizeCallTIRInPlace) + .set_attr("FValidate", ValidateCallTIR) // Warning: considered pure, but it has the potential to create visible effects! // This should only be used if it has been *checked* that it is safe (no aliases, in-place // arguments will no longer be live) From 0d81a926b1f4bfddefcda78add298368bbbc40bb Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 5 Sep 2024 13:57:36 -0500 Subject: [PATCH 6/6] ci bump