diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h index 291bee597c03..0ddc2baefbef 100644 --- a/include/tvm/relax/op_attr_types.h +++ b/include/tvm/relax/op_attr_types.h @@ -56,6 +56,14 @@ using FCallPacked = String; * expressed in multiple syntactically valid and semantically * equivalent forms, to normalize to a single representation. * + * Note: `FNormalize` is applied for each expression as part of the + * `relax::BlockBuilder`. While operator-specific validation may + * be performed within the `FNormalize` implementation, ensuring + * that errors are caught as early as possible, this should only be + * used when validation is fast to apply. If the validation logic + * may be slow, it should instead be implemented in `FValidate`, + * which is only run as part of the well-formed checker. + * * \param bb The BlockBuilder context. * * \param call The call to be normalized. It is provided by-value, to @@ -63,6 +71,25 @@ using FCallPacked = String; */ using FNormalize = runtime::TypedPackedFunc; +/*! + * \brief The function type of a validation function. + * + * A validation function is used to define constraints that should be + * verified for an operator as part of the well-formed checker. + * + * Note: `FValidate` is only applied as part of the well-formed + * checker. While this minimizes overhead while compiling Relax, + * this delay between generating an ill-formed `relax::Call` and + * identifying the ill-formed call may complicate debugging. If + * the validation logic is very fast to check, and doing so would + * not introduce a significant overhead, consider validating as part + * of `FNormalize`, which is applied by the block builder for each + * `relax::Call`. + * + * \param call The call to be validated. + */ +using FValidate = runtime::TypedPackedFunc; + /*! \brief The function type of a legalization function. 
* * A legalization function is used to replace a `relax::Call` with diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 626fadda273d..235059ece2aa 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -352,6 +352,16 @@ class WellFormedChecker : public relax::ExprVisitor, << after_normalize); } } + + if (auto func_validate = op_map_validate_.get(call->op, nullptr); func_validate != nullptr) { + try { + func_validate(GetRef(call)); + } catch (std::exception& err) { + Malformed(Diagnostic::Error(call) << "Operator-specific validation (FValidate) for " + << call->op << " identified error: \n" + << err.what()); + } + } } void VisitExpr_(const IfNode* op) final { @@ -574,6 +584,7 @@ class WellFormedChecker : public relax::ExprVisitor, std::unordered_map symbolic_var_func_map_; tvm::OpAttrMap op_map_normalize_ = Op::GetAttrMap("FNormalize"); + tvm::OpAttrMap op_map_validate_ = Op::GetAttrMap("FValidate"); }; bool WellFormed(Variant obj, bool check_struct_info) { diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 0a840248ffe8..3e0f0eba313a 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -18,6 +18,7 @@ */ #include #include +#include #include #include #include @@ -242,15 +243,195 @@ TVM_REGISTER_GLOBAL("relax.op.call_inplace_packed").set_body_typed(MakeCallInpla // call_tir +/* If possible, infer a legal value of `arg_sinfo` + * + * The `R.call_tir` operator and its variants accept an `arg_sinfo` + * parameter, which specifies the shape of the tensor or tensors + * returned by a PrimFunc. This output shape must be compatible with + * the shape defined by the PrimFunc's signature. + * + * For dynamic shapes, it is not always possible to infer the output + * of a TIR PrimFunc from its inputs. 
For example, a PrimFunc that + * accepts input buffer `T.Buffer([16], "float32")` and output buffer + * `T.Buffer([M, N], "float32")` infers the values of `M` and `N` from + * the shape of the provided output buffer. + * + * If the arguments provided are not compatible with the PrimFunc's + * signature, an error will be raised. If the arguments are + * compatible with the PrimFunc's signature, but are not sufficient to + * determine the output's StructInfo, then `NullOpt` will be returned. + * + * \param func_sinfo The StructInfo of the TIR callee. + * \param arg_sinfo The StructInfo of the argument tuple. + * \param packed_ints_sinfo The StructInfo of the ShapeTuple argument, + * if present. + * \param opt_inplace_indices For `R.call_tir_inplace`, an array of + * indices indicating which outputs are constructed from in-place + * mutation of the inputs. See + * `CallTIRInplaceAttrs::inplace_indices` for more details. + * + * \return The `arg_sinfo`, if it can be inferred from the arguments. + * Otherwise, NullOpt. + */ +static Optional InferCallTIROutputStructInfoFromArguments( + StructInfo func_sinfo, StructInfo arg_sinfo, Optional packed_ints_sinfo, + Optional> opt_inplace_indices) { + auto opt_callee_sinfo = func_sinfo.as(); + CHECK(opt_callee_sinfo) << "TypeError: " + << "The first argument to `R.call_tir` must be a function, " + << "but instead received argument of type " << func_sinfo; + auto callee_sinfo = opt_callee_sinfo.value(); + + CHECK(callee_sinfo->params.defined()) + << "ValueError: " + << "The first argument to `R.call_tir` must be a function " + << "with known argument types. 
" + << "However, the first argument was of type " << callee_sinfo; + auto callee_params = callee_sinfo->params.value(); + + const TupleStructInfoNode* args = arg_sinfo.as(); + CHECK(args) << "TypeError: " + << "The second argument to `R.call_tir` must be a tuple, " + << "but instead received expression of type " << arg_sinfo; + + // R.call_tir expects the PrimFunc to have three groups of arguments. + // + // 1. Input arguments that are explicitly provided as Relax arguments. + // 2. Output tensor arguments. + // 3. Shape arguments, represented as `T.int64` in the PrimFunc, and + // as an optional ShapeExpr argument in the `relax::Call` node. + // + // In order to determine the return type of `R.call_tir`, we must + // identify the PrimFunc arguments that will be in group (2). + size_t num_input_arguments = args->fields.size(); + size_t num_trailing_int_arguments = 0; + const ShapeStructInfoNode* packed_tuple_sinfo = nullptr; + if (packed_ints_sinfo) { + auto packed_sinfo = packed_ints_sinfo.value(); + packed_tuple_sinfo = packed_sinfo.as(); + CHECK(packed_tuple_sinfo && !packed_tuple_sinfo->IsUnknownNdim()) + << "TypeError: " + << "The third argument to `R.call_tir`, if present, " + << "must be a ShapeTuple with known dimensionality. " + << "However, the argument received was of type " << packed_sinfo; + num_trailing_int_arguments = packed_tuple_sinfo->ndim; + } else { + num_trailing_int_arguments = 0; + } + + CHECK_LE(num_input_arguments + num_trailing_int_arguments, callee_params.size()) + << "ValueError: " + << "R.call_tir attempted to call a function using " << num_input_arguments + << " input arguments and " << num_trailing_int_arguments << " trailing integer arguments. " + << "However, the callee only accepts " << callee_params.size() << " arguments in total."; + + // While Relax can specify a distributed tensor, TIR cannot. 
The + // current implementation does not support determining the output + // shape for `R.dist.call_tir` calls, as it depends on the lowering + // of DistIR into regular Relax. + std::function contains_dtensor = [&contains_dtensor](StructInfo sinfo) -> bool { + if (sinfo.as()) { + return true; + } else if (auto tuple = sinfo.as()) { + return std::any_of(tuple->fields.begin(), tuple->fields.end(), contains_dtensor); + } else { + return false; + } + }; + if (contains_dtensor(arg_sinfo)) { + return NullOpt; + } + + // At this point, the return types are known. However, the shapes + // in `callee_params` may contain dynamic shape parameters that are + // not present in the caller's scope. The `DeriveCallRetStructInfo` + // utility can infer the value of dynamic parameters in + // `FuncStructInfoNode::ret` based on definitions in + // `FuncStructInfoNode::params`, inferring the correct values in the + // caller's scope. + // + // Since the callee of `R.call_tir` is provided with output + // arguments, where `DeriveCallRetStructInfo` requires a callee that + // produces its own outputs, a dummy function signature and + // arguments are used. + + auto dummy_callee_sinfo = [&]() -> FuncStructInfo { + Array dummy_params(callee_params.begin(), + callee_params.begin() + num_input_arguments); + + for (size_t i = callee_params.size() - num_trailing_int_arguments; i < callee_params.size(); + i++) { + dummy_params.push_back(callee_params[i]); + } + + Array dummy_ret(callee_params.begin() + num_input_arguments, + callee_params.end() - num_trailing_int_arguments); + + if (opt_inplace_indices) { + // For R.call_tir_inplace, the `inplace_indices` are used to + // indicate which elements of the `out_sinfo` will be generated + // as in-place mutation from an input. For any in-place + // mutation, the parameter's StructInfo must be inserted into + // `out_sinfo`. 
+ auto inplace_indices = opt_inplace_indices.value(); + for (size_t i = 0; i < inplace_indices.size(); i++) { + auto inplace_input_index = inplace_indices[i]->value; + if (inplace_input_index >= 0) { + dummy_ret.insert(dummy_ret.begin() + i, callee_params[inplace_input_index]); + } + } + } + + auto dummy_out_sinfo = [&]() -> StructInfo { + if (dummy_ret.size() == 1) { + return dummy_ret[0]; + } else { + return TupleStructInfo(dummy_ret); + } + }(); + + return FuncStructInfo(dummy_params, dummy_out_sinfo); + }(); + + auto dummy_args = [&]() -> Array { + Array dummy_args = args->fields.Map( + [](const StructInfo& sinfo) -> Expr { return Var("dummy_leading_arg", sinfo); }); + + for (size_t i = 0; i < num_trailing_int_arguments; i++) { + ICHECK(packed_tuple_sinfo); + PrimStructInfo dummy_arg_sinfo = [&]() { + if (packed_tuple_sinfo->values) { + return PrimStructInfo(packed_tuple_sinfo->values.value()[i]); + } else { + return PrimStructInfo(DataType::Int(64)); + } + }(); + dummy_args.push_back(Var("dummy_trailing_arg", dummy_arg_sinfo)); + } + + return dummy_args; + }(); + + auto derived_ret_sinfo = DeriveCallRetStructInfo( + dummy_callee_sinfo, Call(Var("dummy_callee", dummy_callee_sinfo), dummy_args), + BlockBuilder::Create(NullOpt)); + + return derived_ret_sinfo; +} + StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) { if (call->sinfo_args.size() != 1) { ctx->ReportFatal(Diagnostic::Error(call) << "sinfo_args should have exactly 1 output struct info."); } CHECK(call->args[0]->IsInstance()) - << "call_tir expects the first argument to be a GlobalVar referring to a TIR PrimFunc. " - << "However, gets " << call->args[0]; - return call->sinfo_args[0]; + << "R.call_tir expects the first argument to be a GlobalVar referring to a TIR PrimFunc. 
" + << "However, the argument " << call->args[0] << " instead has type " + << call->args[0]->GetTypeKey(); + + StructInfo explicit_sinfo = call->sinfo_args[0]; + + return explicit_sinfo; } Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { @@ -264,23 +445,37 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { << "or three arguments [callee, arg_tuple, tir_args], " << "but " << call << " has " << call->args.size() << " arguments."; - Expr arg_expr = call->args[1]; + auto callee = call->args[0]; + CHECK(callee->struct_info_.as()) + << "Operation " << call->op << " expects the first argument to be a TIR callee. " + << "However, the first argument " << callee << " has struct info " << callee->struct_info_; - CHECK(arg_expr->struct_info_.as()) - << "Operation " << call->op << " expects the second argument to be a tuple of relax Expr. " - << "However, the second argument " << arg_expr << " has struct info " - << arg_expr->struct_info_ << "."; + Expr arg_tuple = call->args[1]; - if (arg_expr.as()) { - return std::move(call); - } + CHECK(arg_tuple->struct_info_.as()) + << "Operation " << call->op << " expects the second argument to be a tuple of relax Expr. " + << "However, the second argument " << arg_tuple << " has struct info " + << arg_tuple->struct_info_ << "."; - CHECK(arg_expr.as()) + CHECK(arg_tuple.as() || arg_tuple.as()) << "Operation " << call->op << " must hold its arguments as an in-line tuple. " - << "However, " << call << " has arguments " << arg_expr + << "However, " << call << " has arguments " << arg_tuple << ", which is neither an in-line tuple, " << "nor a variable binding that may be normalized to an in-line tuple."; + if (call->args.size() > 2) { + Expr packed_ints = call->args[2]; + CHECK(packed_ints->struct_info_.as()) + << "Operation " << call->op << " expects the optional third argument, " + << "if present, to be a ShapeTuple. 
" + << "However, the third argument " << packed_ints << " has struct info " + << packed_ints->struct_info_; + } + + CHECK_EQ(call->sinfo_args.size(), 1) + << "R.call_tir should have exactly one `sinfo_args` parameter, " + << "which defines the output of the PrimFunc."; + auto unwrap_binding = [&ctx](Expr expr) -> Optional { if (auto var = expr.as()) { if (auto bound_value = ctx->LookupBinding(var.value())) { @@ -290,14 +485,21 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { return NullOpt; }; - while (auto unwrapped = unwrap_binding(arg_expr)) { - arg_expr = unwrapped.value(); - } + Tuple new_arg_tuple = [&]() { + // No replacement required. The argument tuple is already + // provided as an in-line tuple. + if (auto opt = arg_tuple.as()) { + return opt.value(); + } + + Expr unwrapped_tuple = arg_tuple; + while (auto unwrapped = unwrap_binding(unwrapped_tuple)) { + unwrapped_tuple = unwrapped.value(); + } - Tuple new_arg_expr = [&]() { // Preferred replacement. The argument tuple is provided as a // variable, but we know the value bound to that variable. - if (auto opt = arg_expr.as()) { + if (auto opt = unwrapped_tuple.as()) { return opt.value(); } @@ -306,20 +508,60 @@ Expr NormalizeCallTIR(const BlockBuilder& ctx, Call call) { // example, if a relax function accepted a tuple as an parameter, // then provided that same tuple as an argument to call_tir. 
Array tuple_elements; - size_t num_fields = Downcast(arg_expr->struct_info_)->fields.size(); + size_t num_fields = Downcast(arg_tuple->struct_info_)->fields.size(); for (size_t i = 0; i < num_fields; i++) { - tuple_elements.push_back(TupleGetItem(arg_expr, i)); + tuple_elements.push_back(TupleGetItem(arg_tuple, i)); } return Tuple(tuple_elements); }(); - auto new_args = call->args; - new_args.Set(1, new_arg_expr); - call.CopyOnWrite()->args = new_args; + if (!new_arg_tuple.same_as(arg_tuple)) { + auto new_args = call->args; + new_args.Set(1, new_arg_tuple); + call.CopyOnWrite()->args = new_args; + } return std::move(call); } +void ValidateCallTIR(Call call) { + // This function is used for validation of `relax.call_tir`, + // along with the variants `relax.call_tir_with_grad` and + // `relax.call_tir_inplace`. Therefore, all error messages should + // be written in terms of `call->op`, and should not explicitly + // reference the `relax.call_tir` operator. + + auto callee = call->args[0]; + Expr arg_tuple = call->args[1]; + + auto packed_int_sinfo = [&]() -> Optional { + if (call->args.size() <= 2) { + return NullOpt; + } else { + return GetStructInfo(call->args[2]); + } + }(); + + auto opt_inplace_indices = [&]() -> Optional> { + if (const auto* attrs = call->attrs.as()) { + return attrs->inplace_indices; + } else { + return NullOpt; + } + }(); + + StructInfo explicit_sinfo = call->sinfo_args[0]; + auto inferred_sinfo = InferCallTIROutputStructInfoFromArguments( + GetStructInfo(callee), GetStructInfo(arg_tuple), packed_int_sinfo, opt_inplace_indices); + if (inferred_sinfo.defined()) { + CHECK(IsBaseOf(inferred_sinfo.value(), explicit_sinfo)) + << "TypeError: " + << "The `out_sinfo` argument for R.call_tir must be compatible with the PrimFunc. 
" + << "However, the PrimFunc's signature implies that the output should be " << inferred_sinfo + << ", but the `out_sinfo` argument was " << explicit_sinfo; + } +} + RELAY_REGISTER_OP("relax.call_tir") .set_num_inputs(3) .add_argument("func", "Expr", "The destination-passing-style function.") @@ -329,6 +571,7 @@ RELAY_REGISTER_OP("relax.call_tir") "args if unused") .set_attr("FInferStructInfo", InferStructInfoCallTIR) .set_attr("FNormalize", NormalizeCallTIR) + .set_attr("FValidate", ValidateCallTIR) .set_attr("FPurity", Bool(true)); Expr MakeCallTIR(Expr func, Tuple args, Array out_sinfo_list, @@ -374,6 +617,7 @@ RELAY_REGISTER_OP("relax.call_tir_with_grad") "args if unused") .set_attr("FInferStructInfo", InferStructInfoCallTIR) .set_attr("FNormalize", NormalizeCallTIR) + .set_attr("FValidate", ValidateCallTIR) .set_attr("FPurity", Bool(true)); Expr MakeCallTIRWithGrad(Expr func, Tuple args, Array out_sinfo_list, @@ -514,6 +758,7 @@ RELAY_REGISTER_OP("relax.call_tir_inplace") "args if unused") .set_attr("FInferStructInfo", InferStructInfoCallTIR) .set_attr("FNormalize", NormalizeCallTIRInPlace) + .set_attr("FValidate", ValidateCallTIR) // Warning: considered pure, but it has the potential to create visible effects! 
// This should only be used if it has been *checked* that it is safe (no aliases, in-place // arguments will no longer be live) diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index b203b322ab96..612e1459c826 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -1088,8 +1088,7 @@ class TIRFuseMutator : public ExprMutator { const auto& [prim_func, indices] = FusedTIRConstructor::GetFusedTIR(mod, old_gvar); GlobalVar new_gvar(old_gvar->name_hint); - UpdateStructInfo(new_gvar, - FuncStructInfo::OpaqueFunc(StructInfoFromType(prim_func->ret_type))); + UpdateStructInfo(new_gvar, GetStructInfo(prim_func)); mod->Remove(old_gvar); updates->Add(new_gvar, prim_func); diff --git a/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py b/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py index e1f45d278d6c..865051b0b4b9 100644 --- a/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py +++ b/tests/python/relax/distributed/test_distributed_transform_propagate_sharding.py @@ -512,13 +512,11 @@ def foo( cls.rotary_embedding, (lv9, cos_cached, sin_cached), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"), - tir_vars=R.shape([256]), ) lv17 = R.call_tir( cls.rotary_embedding, (lv12, cos_cached, sin_cached), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"), - tir_vars=R.shape([256]), ) lv18: R.Tensor((256, 32, 128), dtype="float16") = R.reshape( lv17, R.shape([256, 32, 128]) @@ -712,13 +710,11 @@ def foo( cls.rotary_embedding, (lv9, cos_cached, sin_cached), out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), - tir_vars=R.shape([256]), ) lv17 = R.dist.call_tir( cls.rotary_embedding, (lv12, cos_cached, sin_cached), out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), - tir_vars=R.shape([256]), ) lv18: R.DTensor((256, 32, 128), "float16", "mesh[0]", "S[1]") = R.reshape( lv17, R.shape([256, 32, 128]) 
@@ -1278,13 +1274,11 @@ def foo( cls.rotary_embedding, (lv9, cos_cached, sin_cached), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"), - tir_vars=R.shape([256]), ) lv17 = R.call_tir( cls.rotary_embedding, (lv12, cos_cached, sin_cached), out_sinfo=R.Tensor((1, 256, 32, 128), dtype="float16"), - tir_vars=R.shape([256]), ) lv18 = R.call_tir( cls.reshape1, (lv17,), out_sinfo=R.Tensor((256, 32, 128), dtype="float16") @@ -1449,13 +1443,11 @@ def foo( LlamaAttentionLayerTIR.get_global_var("rotary_embedding"), (lv9, cos_cached, sin_cached), out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), - tir_vars=R.shape([256]), ) lv17 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("rotary_embedding"), (lv12, cos_cached, sin_cached), out_sinfo=R.DTensor((1, 256, 32, 128), "float16", "mesh[0]", "S[2]"), - tir_vars=R.shape([256]), ) lv18 = R.dist.call_tir( LlamaAttentionLayerTIR.get_global_var("reshape1"), diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index 7deddfd28eb9..c0b962c3f3a0 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -14,15 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+ import pytest + import tvm import tvm.testing + from tvm import relax as rx from tvm import tir -from tvm.script import relax as R -from tvm.script import ir as I -from tvm.script import tir as T -from tvm.script import ir as I +from tvm.script import ir as I, relax as R, tir as T m = tir.Var("m", "int64") n = tir.Var("n", "int64") @@ -702,5 +702,511 @@ def is_bfloat16_dtype(tensor: T.handle) -> T.bool: assert rx.analysis.well_formed(Module) +def test_call_tir_with_matching_arguments(): + """R.call_tir is well-formed when called with matching arguments""" + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([16], "float16")): + B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16")) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.float16(1.0) + + assert rx.analysis.well_formed(Module) + + +def test_call_tir_input_ndim(): + """Arguments to R.call_tir must have the correct dimensionality + + Here, the `add_one` function expects a 1-d input tensor, but is + called with a 2-d tensor. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([4, 4], "float16")): + B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16")) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.float16(1.0) + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_output_ndim(): + """Output shape R.call_tir must have the correct dimensionality + + Here, the `add_one` function requires a 1-d output tensor, but is + provided with a 2-d tensor. 
+ """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([16], "float16")): + B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([4, 4], "float16")) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.float16(1.0) + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_input_shape(): + """Arguments to R.call_tir must have the correct shape + + Here, the `add_one` function expects an input tensor with 16 + elements, but is called with an input tensor with 32 elements. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([32], "float16")): + B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16")) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.float16(1.0) + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_output_shape(): + """Output shape R.call_tir must have the correct shape + + Here, the `add_one` function requires an output tensor with 16 + elements, but is provided an output tensor with 32 elements. 
+ """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([16], "float16")): + B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([32], "float16")) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.float16(1.0) + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_input_dtype(): + """Arguments to R.call_tir must have the correct dtype + + Here, the `add_one` function expects an input tensor containing + float16 value, but is called with an input tensor containing + float32 values. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([16], "float32")): + B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float16")) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.float16(1.0) + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_output_dtype(): + """Output shape R.call_tir must have the correct shape + + Here, the `add_one` function requires an output tensor that may be + populated with float16 values, but is provided an output tensor + that may be populated with float32 elements. 
+ + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([16], "float16")): + B = R.call_tir(Module.add_one, A, out_sinfo=R.Tensor([16], "float32")) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16"), B: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + T.float16(1.0) + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_with_correct_dynamic_output_shape(): + """Output shape R.call_tir may not be verifiable + + Here, the input arguments to the `reshape` function are not + sufficient to infer the shape of the outputs. This is legal, + since the output shape is determined by the `out_sinfo` parameter. + + Inability to verify the output shape does not mean that the output + shape is invalid. + + """ + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([16], "float16")): + B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([2, 8], "float16")) + return B + + @T.prim_func + def reshape(A: T.Buffer(16, "float16"), B_handle: T.handle): + M = T.int64() + N = T.int64() + B = T.match_buffer(B_handle, [M, N], dtype="float16") + + for i, j in T.grid(M, N): + with T.block("compute"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi * N + vj] + + assert rx.analysis.well_formed(Module) + + +@pytest.mark.xfail(reason="Not supported") +def test_call_tir_with_incorrect_dynamic_output_shape(): + """Output shape R.call_tir may not be verifiable + + Here, the input arguments to the `reshape` function are not + sufficient to infer the shape of the outputs. Even though the + IRModule will not provide well-defined output due to the + out-of-bounds read from buffer A, catching this error is beyond + the current scope of the Relax well-formed checker. 
+ + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([16], "float16")): + B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([16, 16], "float16")) + return B + + @T.prim_func + def reshape(A: T.Buffer(16, "float16"), B_handle: T.handle): + M = T.int64() + N = T.int64() + B = T.match_buffer(B_handle, [M, N], dtype="float16") + + for i, j in T.grid(M, N): + with T.block("compute"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi * N + vj] + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_incorrect_dimensionality_of_output_shape(): + """Dimensionality may be verified + + Here, the input arguments to the `reshape` function are not + sufficient to infer the shape of the outputs. + + Even though the output shape may not be inferred from the input + arguments, the output dimensionality can still be inferred from + the PrimFunc signature. The IRModule below is ill-formed, because + the PrimFunc requires a 2-d output argument, but is provided with + a 3-d output argument. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([16], "float16")): + B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([2, 4, 2], "float16")) + return B + + @T.prim_func + def reshape(A: T.Buffer(16, "float16"), B_handle: T.handle): + M = T.int64() + N = T.int64() + B = T.match_buffer(B_handle, [M, N], dtype="float16") + + for i, j in T.grid(M, N): + with T.block("compute"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi * N + vj] + + assert not rx.analysis.well_formed(Module) + + +@pytest.mark.xfail(reason="Not yet supported") +def test_call_tir_output_shape_with_mixed_static_and_dynamic(): + """Some dimensions of the R.call_tir output shape may be verifiable + + Here, the input arguments to the `reshape` function are not + sufficient to infer the shape of the outputs. This is legal, + since the output shape is taken from the `out_sinfo` parameter. 
+ + Identifying this failure mode is not yet supported in the current + implementation. This is because the output is inferred as + `R.Tensor(ndim=3, dtype="float16")`, and the explicit `out_sinfo` + is a 3-d tensor. The mismatch in the first dimension is not yet + counted, because the entire tensor shape is removed by + `EraseToWellDefined`. + + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([256], "float16")): + B = R.call_tir(Module.reshape, A, out_sinfo=R.Tensor([8, 16, 2], "float16")) + return B + + @T.prim_func + def reshape(A: T.Buffer(256, "float16"), B_handle: T.handle): + M = T.int64() + N = T.int64() + B = T.match_buffer(B_handle, [16, M, N], dtype="float16") + + for i, j, k in T.grid(16, M, N): + with T.block("compute"): + vi, vj, vk = T.axis.remap("SSS", [i, j, k]) + B[vi, vj, vk] = A[vi * N * M + vj * N + vk] + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_with_correct_inferred_dynamic_output_shape(): + """Some dynamic output shapes of R.call_tir may be inferred + + Here, the `flatten` function is dynamic, and will flatten any 2-d + TIR buffer. Even though it is dynamic, the input shapes are + sufficient to infer that `M==8` and `N==4`. As a result, the + output shape of `[M*N]` can be inferred to be `[32]`, and the + shape specified in `out_sinfo` can be validated. 
+ + """ + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([8, 4], "float16")): + B = R.call_tir(Module.flatten, A, out_sinfo=R.Tensor([32], "float16")) + return B + + @T.prim_func + def flatten(A_handle: T.handle, B_handle: T.handle): + M = T.int64() + N = T.int64() + A = T.match_buffer(A_handle, [M, N], dtype="float16") + B = T.match_buffer(B_handle, [M * N], dtype="float16") + + for i in T.grid(M * N): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi // N, vi % N] + + assert rx.analysis.well_formed(Module) + + +def test_call_tir_with_incorrect_inferred_dynamic_output_shape(): + """Some dynamic output shapes of R.call_tir may be inferred + + Here, the `flatten` function is dynamic, and will flatten any 2-d + TIR buffer. Even though it is dynamic, the input shapes are + sufficient to infer that `M==8` and `N==4`. As a result, the + output shape of `[M*N]` can be inferred to be `[32]`, and the + shape specified in `out_sinfo` can be validated. + + This unit test is identical to the above test + `test_call_tir_with_correct_inferred_dynamic_output_shape`, except + that the output shape is explicitly specified as `[64]`, which is + caught as a mismatch from the expected output shape. 
+ + """ + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([8, 4], "float16")): + B = R.call_tir(Module.flatten, A, out_sinfo=R.Tensor([64], "float16")) + return B + + @T.prim_func + def flatten(A_handle: T.handle, B_handle: T.handle): + M = T.int64() + N = T.int64() + A = T.match_buffer(A_handle, [M, N], dtype="float16") + B = T.match_buffer(B_handle, [M * N], dtype="float16") + + for i in T.grid(M * N): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi // N, vi % N] + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_with_dtensor_arguments(): + """R.call_tir and R.dist.call_tir share the same operation + + Both `R.call_tir` and `R.dist.call_tir` produce the same + "relax.call_tir" operation, differing only in the StructInfo of + their arguments. Normalization of "relax.call_tir" must handle + `R.DTensor` arguments. + + """ + + # from tvm.script.parser import relax as R + + @I.ir_module + class Module: + I.module_attrs({"device_num": 4}) + I.module_global_infos({"mesh": [R.dist.device_mesh([4], I.Range(0, 4))]}) + + @R.function + def main(A: R.dist.DTensor([8, 4], "float16", "mesh[0]", "S[0]")): + B = R.dist.call_tir( + Module.flatten, A, out_sinfo=R.dist.DTensor([64], "float16", "mesh[0]", "S[0]") + ) + return B + + @T.prim_func + def flatten(A_handle: T.handle, B_handle: T.handle): + M = T.int64() + N = T.int64() + A = T.match_buffer(A_handle, [M, N], dtype="float16") + B = T.match_buffer(B_handle, [M * N], dtype="float16") + + for i in T.grid(M * N): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi // N, vi % N] + + assert rx.analysis.well_formed(Module) + + +def test_call_tir_inplace_with_correct_shapes(): + """R.call_tir_inplace is well-formed when called with matching arguments""" + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([16], "float16")): + B = R.call_tir_inplace( + Module.add_one, + A, + inplace_indices=[0], + 
out_sinfo=R.Tensor([16], "float16"), + ) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + A[vi] = A[vi] + T.float16(1.0) + + assert rx.analysis.well_formed(Module) + + +def test_call_tir_inplace_with_incorrect_shapes(): + """R.call_tir_inplace is ill-formed when output shape does not match input""" + + @I.ir_module(check_well_formed=False) + class Module: + @R.function + def main(A: R.Tensor([16], "float16")): + B = R.call_tir_inplace( + Module.add_one, + A, + inplace_indices=[0], + out_sinfo=R.Tensor([32], "float16"), + ) + return B + + @T.prim_func + def add_one(A: T.Buffer(16, "float16")): + for i in range(16): + with T.block("compute"): + vi = T.axis.remap("S", [i]) + A[vi] = A[vi] + T.float16(1.0) + + assert not rx.analysis.well_formed(Module) + + +def test_call_tir_inplace_with_some_allocated_outputs(): + """R.call_tir_inplace may contain some non-inplace outputs""" + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor([16], "float16"), B: R.Tensor([32], "float16")): + out = R.call_tir_inplace( + Module.add_one, + (A, B), + inplace_indices=[-1, 1], + out_sinfo=[ + R.Tensor([16], "float16"), + R.Tensor([32], "float16"), + ], + ) + return out + + @T.prim_func + def add_one( + A: T.Buffer(16, "float16"), + B: T.Buffer(32, "float16"), + C: T.Buffer(16, "float16"), + ): + for i in range(32): + with T.block("inplace_B"): + vi = T.axis.remap("S", [i]) + B[vi] = B[vi] + T.float16(1.0) + + for i in range(16): + with T.block("output_C"): + vi = T.axis.remap("S", [i]) + C[vi] = A[vi] + T.float16(1.0) + + assert rx.analysis.well_formed(Module) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index 64d5c7381171..6005ecb0fa58 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -43,6 +43,7 @@ def normalize(func: 
rx.Function) -> rx.Function: """ Normalize the expr to fill in the checked_type_ and struct_info fields everywhere """ + # using a default mutator to use the BlockBuilder's normalizer, # which oddly differs from the Normalize pass @rx.expr_functor.mutator @@ -435,9 +436,13 @@ def test_call_tir(): @tvm.script.ir_module class TestCallTIR: @T.prim_func - def addone(A: T.Buffer((16, 16), "int32"), B: T.Buffer((16, 16), "int32")) -> None: + def addone(A_handle: T.handle, B_handle: T.handle) -> None: + m = T.int64() + n = T.int64() + A = T.match_buffer(A_handle, (m, n), "float32") + B = T.match_buffer(B_handle, (m, n), "float32") T.func_attr(({"global_symbol": "addone"})) - for i, j in T.grid(16, 16): + for i, j in T.grid(m, n): with T.block("addone"): vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] + T.int32(1) diff --git a/tests/python/relax/test_dataflow_inplace.py b/tests/python/relax/test_dataflow_inplace.py index 8d5eb07c7858..cd6e285de499 100644 --- a/tests/python/relax/test_dataflow_inplace.py +++ b/tests/python/relax/test_dataflow_inplace.py @@ -172,8 +172,8 @@ def tir_id(x: T.handle, y: T.handle) -> None: T.func_attr({"global_symbol": "tir_id"}) m = T.int32() n = T.int32() - A = T.match_buffer(x, (m, n)) - B = T.match_buffer(y, (m, n)) + A = T.match_buffer(x, (m, n), "int32") + B = T.match_buffer(y, (m, n), "int32") for i, j in T.grid(m, n): with T.block("id"): @@ -185,9 +185,9 @@ def tir_id2(x: T.handle, y: T.handle, z: T.handle) -> None: T.func_attr({"global_symbol": "tir_id"}) m = T.int32() n = T.int32() - A = T.match_buffer(x, (m, n)) - B = T.match_buffer(y, (m, n)) - C = T.match_buffer(z, (m, n)) + A = T.match_buffer(x, (m, n), "int32") + B = T.match_buffer(y, (m, n), "int32") + C = T.match_buffer(z, (m, n), "int32") for i, j in T.grid(m, n): with T.block("id"): diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index 03a3beb2f27e..7a3b65cea10e 100644 --- 
a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -72,7 +72,7 @@ def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> lv0 = R.call_tir(cls.tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32")) lv1 = R.call_tir(cls.tir_relu, (lv0), R.Tensor((32, 32), dtype="float32")) lv2 = R.call_tir( - cls.tir_zeros, (lv1), R.Tensor((32,), dtype="float32"), tir_vars=R.ShapeExpr([32]) + cls.tir_zeros, [], R.Tensor((32,), dtype="float32"), tir_vars=R.ShapeExpr([32]) ) gv = (lv1, lv2) R.output(gv) diff --git a/tests/python/relax/test_frontend_dynamo.py b/tests/python/relax/test_frontend_dynamo.py index d83f83f4e188..21e1d82d28b5 100644 --- a/tests/python/relax/test_frontend_dynamo.py +++ b/tests/python/relax/test_frontend_dynamo.py @@ -114,9 +114,10 @@ def main( with db: opt_model = torch.compile(model, backend=relax_dynamo()) inp = torch.randn(10, 100) - tvm.testing.assert_allclose( - opt_model(inp).detach().numpy(), model(inp).detach().numpy(), rtol=1e-5, atol=1e-5 - ) + + default_output = model(inp).detach().numpy() + optimized_output = opt_model(inp).detach().numpy() + tvm.testing.assert_allclose(optimized_output, default_output, rtol=1e-5, atol=1e-5) def test_relax_dynamo_dynamic(): diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 6c3269195498..4990c80e895e 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -570,10 +570,18 @@ def test_tensor_ir_op(): @T.prim_func(private=True) def fused_rope( # pylint: disable=too-many-locals var_qkv: T.handle, - offset: T.int64, var_q: T.handle, var_k: T.handle, var_v: T.handle, + # Scalar arguments must be specified after tensor arguments, + # including the output tensor arguments + # + # TODO(Lunderberg): Update + # `tvm.relax.frontend.nn.op.tensor_ir_op` to use `PrimValue` + # instead of `tir_vars`, so that the order can be consistent + # between the 
function definition and the arguments in + # `op.tensor_ir_op`. + offset: T.int64, ): batch_size = T.int64() seq_len = T.int64() @@ -601,7 +609,7 @@ def test(self, qkv: Tensor, offset: tir.Var): @I.ir_module class Expected: @T.prim_func(private=True) - def llama_fused_rope(var_qkv: T.handle, offset: T.int64, var_q: T.handle, var_k: T.handle, var_v: T.handle): + def llama_fused_rope(var_qkv: T.handle, var_q: T.handle, var_k: T.handle, var_v: T.handle, offset: T.int64): batch_size, seq_len = T.int64(), T.int64() qkv = T.match_buffer(var_qkv, (batch_size, seq_len, 24, 16), "float16") q = T.match_buffer(var_q, (batch_size, seq_len, 8, 16), "float16") @@ -669,10 +677,11 @@ class Model(Module): def test( self, embedding_table: Tensor, input_ids: Tensor, embedding_dst: Tensor, offset: int ): - tensor_expr_op_out = op.tensor_ir_op( + tensor_expr_op_out = op.tensor_ir_inplace_op( inplace_take, "inplace_take", args=[embedding_table, input_ids, embedding_dst, offset], + inplace_indices=[2], out=Tensor.placeholder(embedding_dst.shape, embedding_dst.dtype), ) return tensor_expr_op_out @@ -719,10 +728,11 @@ def test( R.func_attr({"num_input": 4}) cls = Expected with R.dataflow(): - lv1 = R.call_tir( + lv1 = R.call_tir_inplace( cls.inplace_take, (embedding_table, input_ids, embedding_dst), out_sinfo=R.Tensor((total_seq_len, hidden_size), dtype), + inplace_indices=[2], tir_vars=R.shape([offset_1]), ) gv1: R.Tensor((total_seq_len, hidden_size), dtype) = lv1 diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index ee2df866fb35..e3274aea886a 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -86,7 +86,11 @@ def test_call_tir_rewrite(): @tvm.script.ir_module class TestCallTIRRewrite: @T.prim_func - def exp(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + def exp(A_handle: T.handle, B_handle: T.handle): + m = T.int64() + n = T.int64() + A = T.match_buffer(A_handle, (m, n), "float32") + 
B = T.match_buffer(B_handle, (m, n), "float32") T.evaluate(0) @R.function diff --git a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py index 142faf51607b..fb915262cc81 100644 --- a/tests/python/relax/test_transform_dead_code_elimination.py +++ b/tests/python/relax/test_transform_dead_code_elimination.py @@ -277,18 +277,26 @@ def main(x: R.Tensor((16, 16), "float32")) -> R.Tensor((16, 16), "float32"): def test_unused_relax_func_symbolic_shape(): # Test with relax function w/ symbolic shape. - @tvm.script.ir_module + @tvm.script.ir_module(check_well_formed=False) class InputModule: @T.prim_func - def tir_add( - x: T.Buffer((16, 16), "float32"), - y: T.Buffer((16, 16), "float32"), - z: T.Buffer((16, 16), "float32"), + def tir_matmul( + x_handle: T.handle, + y_handle: T.handle, + z_handle: T.handle, ) -> None: - for i, j in T.grid(16, 16): - with T.block("add"): - vi, vj = T.axis.remap("SS", [i, j]) - z[vi, vj] = x[vi, vj] + y[vi, vj] + m = T.int64() + n = T.int64() + k = T.int64() + x = T.match_buffer(x_handle, (m, n), "float32") + y = T.match_buffer(y_handle, (n, k), "float32") + z = T.match_buffer(z_handle, (m, k), "float32") + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + z[vi, vj] = 0.0 + z[vi, vj] = z[vi, vj] + x[vi, vk] * y[vk, vj] @R.function(private=True) def unused_func(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")): @@ -298,7 +306,7 @@ def unused_func(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "flo @R.function def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")): m, k = T.int64(), T.int64() - gv0 = R.call_tir(InputModule.tir_add, (x, w), R.Tensor((m + 1, k), dtype="float32")) + gv0 = R.call_tir(InputModule.tir_matmul, (x, w), R.Tensor((m, k), dtype="float32")) return gv0 mod = InputModule @@ -306,7 +314,7 @@ def main(x: R.Tensor(("m", "n"), 
"float32"), w: R.Tensor(("n", "k"), "float32")) new_mod = DeadCodeElimination()(mod) assert check_if_func_exists(new_mod, "main") - assert check_if_func_exists(new_mod, "tir_add") + assert check_if_func_exists(new_mod, "tir_matmul") assert not check_if_func_exists(new_mod, "unused_func") diff --git a/tests/python/relax/test_transform_fuse_ops.py b/tests/python/relax/test_transform_fuse_ops.py index 17bf58613294..9ad66bec012a 100644 --- a/tests/python/relax/test_transform_fuse_ops.py +++ b/tests/python/relax/test_transform_fuse_ops.py @@ -875,7 +875,7 @@ class Module: def main(x: R.Tensor((1, 512, 64, 64), "float32"), mean: R.Tensor((64, 64), "float32"), var: R.Tensor((64, 64), "float32")): cls = Module with R.dataflow(): - gv0 = R.call_tir(cls.layer_norm, (x, mean, var), out_sinfo=R.Tensor((1, 512, 64, 64))) + gv0 = R.call_tir(cls.layer_norm, (x, mean, var), out_sinfo=R.Tensor((1, 512, 64, 64), 'float32')) gv1 = R.call_tir(cls.relu, gv0, out_sinfo=R.Tensor((1, 512, 64, 64), "float32")) R.output(gv1) return gv1 @@ -955,7 +955,7 @@ def fused_layer_norm_relu(x: R.Tensor((1, 512, 64, 64), dtype="float32"), mean: R.func_attr({"Primitive": 1}) cls = Expected with R.dataflow(): - gv0 = R.call_tir(cls.layer_norm, (x, mean, var), out_sinfo=R.Tensor((1, 512, 64, 64))) + gv0 = R.call_tir(cls.layer_norm, (x, mean, var), out_sinfo=R.Tensor((1, 512, 64, 64), 'float32')) gv = R.call_tir(cls.relu, (gv0,), out_sinfo=R.Tensor((1, 512, 64, 64), dtype="float32")) R.output(gv) return gv @@ -1452,7 +1452,7 @@ def main( R.Tensor((2,), "float32"), R.Tensor((2,), "float32"), R.Tensor((2,), "float32"), - ) + ), ): with R.dataflow(): x0 = x[0] @@ -1486,7 +1486,7 @@ def main( R.Tensor((2,), dtype="float32"), R.Tensor((2,), dtype="float32"), R.Tensor((2,), dtype="float32"), - ) + ), ) -> R.Tensor((2,), dtype="float32"): cls = Expected with R.dataflow(): diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py index 
1582526042f1..a07875fcdae6 100644 --- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py +++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py @@ -696,10 +696,10 @@ def test_ignore_call_tir(): class Conv2dReLUCallTIR: @T.prim_func def relu( - data: T.Buffer((64, 64, 56, 56), "float32"), - out: T.Buffer((64, 64, 56, 56), "float32"), + data: T.Buffer((1, 64, 56, 56), "float32"), + out: T.Buffer((1, 64, 56, 56), "float32"), ): - for ax0, ax1, ax2, ax3 in T.grid(64, 64, 56, 56): + for ax0, ax1, ax2, ax3 in T.grid(1, 64, 56, 56): with T.block("root"): i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) out[i, j, k, l] = T.max(data[i, j, k, l], 0.0) @@ -714,7 +714,7 @@ def main( relu1 = R.call_tir( Conv2dReLUCallTIR.relu, (conv1,), - R.Tensor((64, 64, 56, 56), "float32"), + R.Tensor((1, 64, 56, 56), "float32"), ) R.output(relu1) @@ -724,11 +724,11 @@ def main( class Conv2dReLUCallTIR_partitioned: @T.prim_func def relu( - data: T.Buffer((64, 64, 56, 56), "float32"), - out: T.Buffer((64, 64, 56, 56), "float32"), + data: T.Buffer((1, 64, 56, 56), "float32"), + out: T.Buffer((1, 64, 56, 56), "float32"), ): # with T.block("root"): - for ax0, ax1, ax2, ax3 in T.grid(64, 64, 56, 56): + for ax0, ax1, ax2, ax3 in T.grid(1, 64, 56, 56): with T.block("root"): i, j, k, l = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(data[i, j, k, l]) @@ -754,7 +754,7 @@ def fused_relax_nn_conv2d( def main( data: R.Tensor((1, 64, 56, 56), dtype="float32"), weight1: R.Tensor((64, 64, 3, 3), dtype="float32"), - ) -> R.Tensor((64, 64, 56, 56), dtype="float32"): + ) -> R.Tensor((1, 64, 56, 56), dtype="float32"): cls = Conv2dReLUCallTIR_partitioned with R.dataflow(): lv: R.Tensor((1, 64, 56, 56), dtype="float32") = cls.fused_relax_nn_conv2d( @@ -763,7 +763,7 @@ def main( relu1 = R.call_tir( cls.relu, (lv,), - out_sinfo=R.Tensor((64, 64, 56, 56), dtype="float32"), + out_sinfo=R.Tensor((1, 64, 56, 56), dtype="float32"), ) R.output(relu1) return relu1 diff --git 
a/tests/python/relax/test_transform_lazy_transform_params.py b/tests/python/relax/test_transform_lazy_transform_params.py index 278ac825f7a7..87a5698f1bf8 100644 --- a/tests/python/relax/test_transform_lazy_transform_params.py +++ b/tests/python/relax/test_transform_lazy_transform_params.py @@ -43,7 +43,7 @@ def transform_layout_IOHW_to_OIHW( def main_transform_params( params: R.Tuple( R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32") - ) + ), ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32") ): @@ -124,7 +124,7 @@ def transform_layout_IOHW_to_OIHW( def main_transform_params( params: R.Tuple( R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32") - ) + ), ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32") ): @@ -209,7 +209,7 @@ def transform_layout_IOHW_to_OIHW( def main_transform_params( params: R.Tuple( R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32") - ) + ), ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32") ): @@ -298,7 +298,7 @@ def transform_layout_IOHW_to_OIHW( def main_transform_params( params: R.Tuple( R.Tensor((3, 16, 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32") - ) + ), ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor((16, 3, 3, 3), dtype="float32") ): @@ -441,8 +441,8 @@ def main_transform_params( @T.prim_func(private=True) def slice_buffer( Input: T.Buffer((16, 16), "float32"), - slice_index: T.int64, Output: T.Buffer(16, "float32"), + slice_index: T.int64, ): for i in T.grid(16): with T.block("slice_buffer"): @@ -479,8 +479,8 @@ def main_transform_params(slice_shape_expr: R.Shape(["slice_index"])): @T.prim_func(private=True) def slice_buffer( Input: T.Buffer((16, 16), "float32"), - slice_index: T.int64, Output: T.Buffer(16, "float32"), + slice_index: T.int64, ): for i in 
T.grid(16): with T.block("slice_buffer"): @@ -511,7 +511,7 @@ def main_transform_params( params: R.Tuple( R.Tensor((3, "ic", 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32"), - ) + ), ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor(("ic", 3, 3, 3), dtype="float32") ): @@ -637,7 +637,7 @@ def transform_params( params: R.Tuple( R.Tensor((3, "ic", 3, 3), dtype="float32"), R.Tensor((16, 16, 3, 3), dtype="float32"), - ) + ), ) -> R.Tuple( R.Tensor((16, 16, 3, 3), dtype="float32"), R.Tensor(("ic", 3, 3, 3), dtype="float32") ): @@ -691,7 +691,7 @@ def test_duplicate_outputs(): class Before: @R.function def main_transform_params( - params: R.Tuple(R.Tensor([16], dtype="int32"), R.Tensor([16], dtype="int32")) + params: R.Tuple(R.Tensor([16], dtype="int32"), R.Tensor([16], dtype="int32")), ): R.func_attr({"relax.force_pure": True}) param0 = params[0] @@ -966,7 +966,7 @@ def transform_params( class Expected: @R.function def transform_params( - fget_param: R.Callable([R.Prim("int64"), R.Object], R.Object) + fget_param: R.Callable([R.Prim("int64"), R.Object], R.Object), ) -> R.Tuple(R.Tensor(ndim=2, dtype="float32"), R.Tensor(ndim=2, dtype="float32")): R.func_attr({"num_input": 1}) m = T.int64() diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py index f7befd3b886a..5a7d76d8fe41 100644 --- a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py +++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py @@ -252,11 +252,15 @@ def reshape(var_A: T.handle, var_T_reshape: T.handle): ] @R.function - def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor((2, 4, 3), dtype="float32"): + def main( + x: R.Tensor((8, 16, 128), dtype="float16") + ) -> R.Tensor((1, 8, 16, 128), dtype="float16"): cls = Module with R.dataflow(): - y = R.call_tir(cls.reshape, (x,), out_sinfo=R.Tensor((2, 4, 3), dtype="float32")) - z = R.add(y, R.const(1, "float32")) + y 
= R.call_tir( + cls.reshape, (x,), out_sinfo=R.Tensor((1, 8, 16, 128), dtype="float16") + ) + z = R.add(y, R.const(1, "float16")) R.output(z) return z @@ -290,10 +294,14 @@ def reshape(var_A: T.handle, var_T_reshape: T.handle): ] @R.function - def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor((2, 4, 3), dtype="float32"): + def main( + x: R.Tensor((8, 16, 128), dtype="float16") + ) -> R.Tensor((1, 8, 16, 128), dtype="float16"): with R.dataflow(): - y: R.Tensor((2, 4, 3), dtype="float32") = R.reshape(x, R.shape([2, 4, 3])) - z: R.Tensor((2, 4, 3), dtype="float32") = R.add(y, R.const(1, "float32")) + y: R.Tensor((1, 8, 16, 128), dtype="float16") = R.reshape( + x, R.shape([1, 8, 16, 128]) + ) + z: R.Tensor((1, 8, 16, 128), dtype="float16") = R.add(y, R.const(1, "float16")) R.output(z) return z @@ -383,7 +391,7 @@ def main( R.Tensor((2, 4096, 320), dtype="float16"), R.Tensor((2, 4096, 320), dtype="float16"), R.Tensor((2, 4096, 320), dtype="float16"), - ) + ), ) -> R.Tensor((2, 4096, 8, 40), dtype="float16"): cls = Module with R.dataflow(): @@ -444,7 +452,7 @@ def main( R.Tensor((2, 4096, 320), dtype="float16"), R.Tensor((2, 4096, 320), dtype="float16"), R.Tensor((2, 4096, 320), dtype="float16"), - ) + ), ) -> R.Tensor((2, 4096, 8, 40), dtype="float16"): with R.dataflow(): lv: R.Tensor((2, 4096, 320), dtype="float16") = lv41_1[0] @@ -735,7 +743,6 @@ def add( z_handle: T.handle, N: T.int64, ): - y1 = T.match_buffer(y1_handle, [N * 4, T.int64(4)], "float32") y2 = T.match_buffer(y2_handle, [N * 4, T.int64(4)], "float32") z = T.match_buffer(z_handle, [N * 4, T.int64(4)], "float32") diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 4f41b662caf2..deb751738946 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -77,7 +77,7 @@ def test_mismatch_cast_dims_and_ndim(): @R.function def f( - x: R.Tensor((2, 3), "float32", ndim=3) + x: R.Tensor((2, 3), 
"float32", ndim=3), ): # error: ndim and the shape dims are mismatch return x @@ -961,11 +961,11 @@ def test_call_tir_with_tir_var(): class Module: @R.function def main( - dumb_param: R.Tensor(("n",), "float32"), x: R.Tensor(("n * 2", "float32")) + dumb_param: R.Tensor(("n",), "float32"), x: R.Tensor(("n * 2",), "float32") ) -> R.Tensor(("n * 2",), "float32"): n = T.int64() cls = Module - y = R.call_tir(cls.copy, (x,), R.Tensor(((n * 2,)), dtype="float32"), tir_vars=(n,)) + y = R.call_tir(cls.copy, x, R.Tensor((n * 2,), dtype="float32"), tir_vars=(n,)) return y @T.prim_func @@ -2135,7 +2135,9 @@ def func(z: R.Tensor((4, 4), "float32")): @R.function(private=True) def expect(z: R.Tensor((4, 4), dtype="float32")) -> R.Shape([4, 4]): alloc: R.Tensor((4, 4), dtype="float32") = R.builtin.alloc_tensor( - R.shape([4, 4]), R.dtype("float32"), R.prim_value(2) # Make sure prim_value is 2 + R.shape([4, 4]), + R.dtype("float32"), + R.prim_value(2), # Make sure prim_value is 2 ) shape: R.Shape([4, 4]) = R.shape_of(alloc) shape_1: R.Shape([4, 4]) = shape @@ -2167,7 +2169,9 @@ def func(z: R.Tensor((4, 4), "float32")): @R.function(private=True) def expect(z: R.Tensor((4, 4), dtype="float32")) -> R.Shape([4, 4]): alloc: R.Tensor((4, 4), dtype="float32") = R.builtin.alloc_tensor( - R.shape([4, 4]), R.dtype("float32"), R.prim_value(1) # Make sure prim_value is 1 + R.shape([4, 4]), + R.dtype("float32"), + R.prim_value(1), # Make sure prim_value is 1 ) shape: R.Shape([4, 4]) = R.shape_of(alloc) shape_1: R.Shape([4, 4]) = shape @@ -2336,7 +2340,6 @@ def explicit_sinfo( B: R.Tensor(["N"], "float32"), cond: R.Prim("bool"), ) -> R.Tensor(["N"], "float32"): - N = T.int64() if cond: diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index 30fd06d4f14d..ecf33aa9da1e 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -988,8 +988,10 @@ class ModA: I.module_attrs({"system_lib_prefix": "libA_"}) @T.prim_func - def 
tir_init(x: T.Buffer((2), "float32")) -> None: - for i in range(2): + def tir_init(x_handle: T.handle): + N = T.int64() + x = T.match_buffer(x_handle, [N], "float32") + for i in range(N): x[i] = T.float32(0) @R.function @@ -1003,8 +1005,10 @@ class ModB: I.module_attrs({"system_lib_prefix": "libB_"}) @T.prim_func - def tir_init(x: T.Buffer((2), "float32")) -> None: - for i in range(2): + def tir_init(x_handle: T.handle): + N = T.int64() + x = T.match_buffer(x_handle, [N], "float32") + for i in range(N): x[i] = T.float32(1) @R.function