From 27a6820d9491ce459e2044ed0d3e6b98a2300fac Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Tue, 16 Apr 2024 09:45:45 -0500
Subject: [PATCH] [Relax] Implement
 relax.transform.RemoveSymbolicExpressionInSubroutine

This is a follow-up commit to https://github.com/apache/tvm/pull/16637,
which updated `relax.transform.FuseOps` to provide additional
parameters defining the symbolic variables required by the fused
functions.  While this ensures that `relax.transform.FuseOps` produces
well-formed Relax functions, the additional arguments can break kernel
implementations that do not expect them.  This commit implements a new
transform, `RemoveSymbolicExpressionInSubroutine`, to resolve this
issue.

The transform identifies function arguments whose sole purpose is to
define the variables of a symbolic expression, when the value of that
expression could instead be inferred from a tensor shape.  For example,
consider the following Relax function:

```python
@R.function
def func(
    data: R.Tensor(["batch_size * seq_len", "hidden_size"]),
    weights: R.Tensor(["hidden_size", "intermediate_size"]),
    dummy_arg: R.Shape(["batch_size", "seq_len"]),
) -> R.Tensor(["batch_size * seq_len", "intermediate_size"]):
    batch_size = T.int64()
    seq_len = T.int64()
    intermediate_size = T.int64()
    hidden_size = T.int64()
    output: R.Tensor([batch_size * seq_len, intermediate_size]) = R.matmul(data, weights)
    return output
```

The `data` tensor may be used to infer `hidden_size`, but cannot be
used to infer `batch_size` or `seq_len`, since only their product
appears in its shape.  The `R.Shape` parameter exists solely to define
`batch_size` and `seq_len`, since every symbolic variable must be
defined by some parameter.  However, neither `batch_size` nor `seq_len`
is ever used outside of the expression `batch_size * seq_len`, and the
value of `batch_size * seq_len` could be inferred from the shape of the
`data` tensor.

The new transform identifies cases where an argument is otherwise
unnecessary, and replaces the symbolic expression with a fresh variable
that can be inferred from a tensor shape (here, `data_dim0`, named
after the first dimension of `data`).  This leaves the
`dummy_arg: R.Shape` parameter entirely unused, so a later application
of `relax.transform.RemoveUnusedParameters` can remove it altogether.

```python
@R.function
def func(
    data: R.Tensor(["data_dim0", "hidden_size"]),
    weights: R.Tensor(["hidden_size", "intermediate_size"]),
    dummy_arg: R.Shape(["batch_size", "seq_len"]),
) -> R.Tensor(["data_dim0", "intermediate_size"]):
    data_dim0 = T.int64()
    intermediate_size = T.int64()
    hidden_size = T.int64()
    output: R.Tensor([data_dim0, intermediate_size]) = R.matmul(data, weights)
    return output
```
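The pass is intended to run together with
`relax.transform.RemoveUnusedParameters`, as the unit tests below do.
A minimal sketch of that pipeline, assuming the function above has been
wrapped into an `IRModule` named `Module`:

```python
import tvm
from tvm import relax

# Replace inferable symbolic expressions, then drop the now-unused
# `dummy_arg` parameter from the subroutine and its call sites.
pipeline = tvm.ir.transform.Sequential(
    [
        relax.transform.RemoveSymbolicExpressionInSubroutine(),
        relax.transform.RemoveUnusedParameters(),
    ]
)
Module = pipeline(Module)
```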
---
 include/tvm/relax/transform.h                 |  17 +
 python/tvm/relax/transform/__init__.py        |   3 +-
 python/tvm/relax/transform/transform.py       |  52 +++
 ...emove_symbolic_expression_in_subroutine.cc | 247 +++++++++++
 .../transform/remove_unused_parameters.cc     |  16 +-
 ...emove_symbolic_expression_in_subroutine.py | 386 ++++++++++++++++++
 ...test_transform_remove_unused_parameters.py |   7 +-
 7 files changed, 721 insertions(+), 7 deletions(-)
 create mode 100644 src/relax/transform/remove_symbolic_expression_in_subroutine.cc
 create mode 100644 tests/python/relax/test_transform_remove_symbolic_expression_in_subroutine.py

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 5a7b85ac1376..27ac65d3f760 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -307,6 +307,23 @@ TVM_DLL Pass RemoveUnusedParameters();
  */
 TVM_DLL Pass RemoveUnusedOutputs();
 
+/*!
+ * \brief Remove unnecessary symbolic expressions in subroutines
+ *
+ * If all occurrences of a symbolic variable within a subroutine
+ * occur within the same symbolic expression, then the subroutine
+ * could be simplified to be in terms of that expression.
+ *
+ * For example, if a subroutine accepts symbolic shape parameters `N`
+ * and `M`, and the variables `N` and `M` are only ever used to
+ * compute `N*M`, then the subroutine could instead accept a symbolic
+ * shape parameter `new_var = N*M`.  This can allow shape parameters
+ * to be inferred from tensor shapes, rather than requiring additional
+ * arguments.
+ *
+ * \return The pass
+ */
+TVM_DLL Pass RemoveSymbolicExpressionInSubroutine();
+
 /*!
  * \brief Annotate Op Pattern Kind for TIR functions, which is used in FuseOps.
  * \note It is an auto-detect pass for "unscheduled prim_funcs", the op_pattern will be
diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index 1ce864651cd9..e8c70eca26c7 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Relax transformations. """
+"""Relax transformations."""
 
 from .transform import (
     AdjustMatmulOrder,
@@ -65,6 +65,7 @@
     PatternCheckContext,
     RealizeVDevice,
     RemovePurityChecking,
+    RemoveSymbolicExpressionInSubroutine,
     RemoveUnusedOutputs,
     RemoveUnusedParameters,
     ReorderPermuteDimsAfterConcat,
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 95649f331f33..52e2215ab18e 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -764,6 +764,58 @@ def RemoveUnusedOutputs() -> tvm.ir.transform.Pass:
     return _ffi_api.RemoveUnusedOutputs()  # type: ignore
 
 
+def RemoveSymbolicExpressionInSubroutine() -> tvm.ir.transform.Pass:
+    """Remove unnecessary symbolic expressions in subroutines
+
+    If all occurrences of a symbolic variable within a subroutine
+    occur within the same symbolic expression, then the subroutine
+    could be simplified to be in terms of that expression.
+
+    For example, consider an elementwise operation that takes input of
+    shape `arg: R.Tensor([m * n])`, producing output of shape
+    `R.Tensor([m * n])`.  The symbolic variables `m` and `n` cannot be
+    inferred from the shape of `arg`, as only their product `m*n` can
+    be determined from the tensor's shape.  In order to be
+    well-formed, Relax requires one of the following three
+    workarounds.
+
+    1. Remove the symbolic variables, producing `arg:
+       R.Tensor(ndim=1)`.  This no longer provides the symbolic
+       variables, and is well-formed.  However, this also causes the
+       output shape to be `R.Tensor(ndim=1)`.  The calling scope can
+       no longer determine that the input and output shapes are
+       identical.
+
+       This is the default behavior of the `relax::BlockBuilder`.
+
+    2. Provide an additional argument to define the symbolic
+       variables.  If the elementwise operation takes an additional
+       argument `R.Shape([m, n])`, then that additional argument
+       would define the symbolic variables.
+
+       This is the output produced by `relax.transform.FuseOps`, and
+       while it is well-formed, the additional non-tensor argument can
+       be unexpected by downstream transforms.
+
+    3. Update the shape of `arg` to `R.Tensor([arg_size])`.  This
+       allows the symbolic variable `arg_size` to be inferred from the
+       tensor's shape, and propagates to the output shape of
+       `R.Tensor([arg_size])`.  Within the calling scope, an
+       argument of `R.Tensor([m * n])` can then be inferred to produce
+       an output of `R.Tensor([m * n])`, without requiring an
+       additional parameter to provide the shape.
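+
+    As an illustration (using the names `m`, `n`, and `arg_size` from
+    above), the three signatures would be:
+
+        arg: R.Tensor(ndim=1)                          # option (1)
+        arg: R.Tensor([m * n]), _: R.Shape([m, n])     # option (2)
+        arg: R.Tensor([arg_size])                      # option (3)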
+
+    This transform updates internal functions that use option (2) to
+    instead use option (3).
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+
+    """
+    return _ffi_api.RemoveSymbolicExpressionInSubroutine()  # type: ignore
+
+
 def InlinePrivateFunctions() -> tvm.ir.transform.Pass:
     """Inline all private relax functions
diff --git a/src/relax/transform/remove_symbolic_expression_in_subroutine.cc b/src/relax/transform/remove_symbolic_expression_in_subroutine.cc
new file mode 100644
index 000000000000..2782bbd922f2
--- /dev/null
+++ b/src/relax/transform/remove_symbolic_expression_in_subroutine.cc
@@ -0,0 +1,247 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/remove_symbolic_expression_in_subroutine.cc
+ *
+ * \brief Replace symbolic expressions with single variables, when possible.
+ *
+ * For example, if a subroutine accepts symbolic shape parameters `N`
+ * and `M`, and the variables `N` and `M` are only ever used to
+ * compute `N*M`, then the subroutine could instead accept a symbolic
+ * shape parameter `new_var = N*M`.  This can allow shape parameters
+ * to be inferred from tensor shapes, rather than requiring additional
+ * arguments.
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr_functor.h>
+
+#include <algorithm>
+#include <sstream>
+#include <string>
+
+namespace tvm {
+namespace relax {
+
+namespace {
+
+// Utility templates for unordered map/set that use structural hash/equal.
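+// For example, `StructMap<PrimExpr, std::string>` treats two
+// structurally-equal keys (e.g. two distinct occurrences of the
+// expression `n*m`) as the same entry.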
+
+template <typename Key, typename Value>
+using StructMap = std::unordered_map<Key, Value, StructuralHash, StructuralEqual>;
+
+template <typename Key>
+using StructSet = std::unordered_set<Key, StructuralHash, StructuralEqual>;
+
+/*! \brief Collect symbolic expressions that may be inferred from a function signature
+ *
+ * \param func The function whose signature should be inspected
+ *
+ * \return A map from PrimExpr to the location where it occurs in the signature
+ */
+StructMap<PrimExpr, std::string> CollectInferableExpressions(const Function& func) {
+  StructMap<PrimExpr, std::string> output;
+
+  auto mark = [&](const PrimExpr& expr, const ObjectPath& path) {
+    if (!output.count(expr)) {
+      std::stringstream ss;
+      ss << path;
+      output[expr] = ss.str();
+    }
+  };
+
+  std::function<void(const StructInfo&, const ObjectPath&)> visit =
+      [&](const StructInfo& sinfo, const ObjectPath& path) {
+        if (auto tensor = sinfo.as<TensorStructInfoNode>()) {
+          if (auto opt_shape = tensor->GetShape()) {
+            auto shape_path = path->Attr("shape");
+            auto shape = opt_shape.value();
+            for (size_t i = 0; i < shape.size(); i++) {
+              mark(shape[i], shape_path->ArrayIndex(i));
+            }
+          }
+        } else if (auto tuple = sinfo.as<TupleStructInfoNode>()) {
+          for (size_t i = 0; i < tuple->fields.size(); i++) {
+            visit(tuple->fields[i], path->ArrayIndex(i));
+          }
+        }
+      };
+
+  for (const auto& param : func->params) {
+    visit(GetStructInfo(param), ObjectPath::Root(param->name_hint()));
+  }
+
+  return output;
+}
+
+/*! \brief Collect expressions that are required in a function body
+ *
+ * This recurses into StructInfo and sub-expressions, but does not
+ * recurse beyond any expression in `inferable_expressions`.  This
+ * allows the transform to determine whether a `tir::Var` ever occurs
+ * outside of an expression that can be inferred.
+ */
+class RequiredExpressionCollector : private StructInfoVisitor,
+                                    private ExprVisitor,
+                                    private tir::ExprVisitor {
+ public:
+  static StructSet<PrimExpr> Collect(
+      const Function& func,
+      const StructMap<PrimExpr, std::string>& inferable_expressions) {
+    RequiredExpressionCollector visitor(inferable_expressions);
+    visitor.VisitExpr(func->body);
+    return visitor.required_expressions_;
+  }
+
+ private:
+  explicit RequiredExpressionCollector(
+      const StructMap<PrimExpr, std::string>& inferable_expressions)
+      : inferable_expressions_(inferable_expressions) {}
+
+  using relax::ExprVisitor::VisitExpr;
+  using tir::ExprVisitor::VisitExpr;
+
+  // Required in order to recurse from `TensorStructInfo` into its
+  // `ShapeExpr`.  This hands control from `StructInfoVisitor` into
+  // `ExprVisitor`.
+  void VisitStructInfoExprField(const Expr& expr) override { VisitExpr(expr); }
+
+  // Required in order to recurse into `ShapeStructInfo`.  This hands
+  // control from `ExprVisitor` back to `StructInfoVisitor`.
+  void VisitExprDepStructInfoField(const StructInfo& struct_info) override {
+    VisitStructInfo(struct_info);
+  }
+
+  void VisitPrimExpr(const PrimExpr& expr) override {
+    required_expressions_.insert(expr);
+    if (!inferable_expressions_.count(expr)) {
+      tir::ExprVisitor::VisitExpr(expr);
+    }
+  }
+
+  void VisitStructInfoExprField(const PrimExpr& expr) override { VisitPrimExpr(expr); }
+
+  const StructMap<PrimExpr, std::string>& inferable_expressions_;
+  StructSet<PrimExpr> required_expressions_;
+};
+
+/*! \brief Replace occurrences of a PrimExpr within symbolic expressions
+ *
+ * In most cases, the `tvm::relax::Bind` utility should be used
+ * instead.  Here, though, we are replacing a `PrimExpr` with a
+ * `tir::Var`, whereas `tvm::relax::Bind` supports the more standard
+ * case of replacing a `tir::Var` with a `PrimExpr`.
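+ *
+ * For example, given the replacement map `{N*M -> NM}`, a parameter
+ * with struct info `R.Tensor([N*M, K])` is rewritten to
+ * `R.Tensor([NM, K])`, while expressions that match no key are
+ * recursed into and otherwise left unchanged.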
+ */ +class SymbolicSubexprReplacer : public relax::ExprMutator, + public StructInfoMutator, + public tir::ExprMutator { + public: + using relax::ExprMutator::operator(); + using relax::ExprMutator::VisitExpr; + using tir::ExprMutator::operator(); + using tir::ExprMutator::VisitExpr; + + explicit SymbolicSubexprReplacer(StructMap replacements) + : replacements_(replacements) {} + + StructInfo VisitExprDepStructInfoField(const StructInfo& struct_info) override { + return VisitStructInfo(struct_info); + } + Expr VisitStructInfoExprField(const Expr& expr) override { return VisitExpr(expr); } + PrimExpr VisitStructInfoExprField(const PrimExpr& expr) override { return VisitExpr(expr); } + PrimExpr VisitPrimExpr(const PrimExpr& expr) override { return VisitExpr(expr); } + + PrimExpr VisitExpr(const PrimExpr& expr) override { + if (auto it = replacements_.find(expr); it != replacements_.end()) { + return it->second; + } else { + return tir::ExprMutator::VisitExpr(expr); + } + } + + StructMap replacements_; +}; + +} // namespace + +Function RemoveSymbolicExpressionInSubroutine(Function func) { + bool is_exposed_externally = func->GetAttr(tvm::attr::kGlobalSymbol).defined(); + if (is_exposed_externally) return func; + + auto inferable_expressions = CollectInferableExpressions(func); + + auto required_expressions = RequiredExpressionCollector::Collect(func, inferable_expressions); + + StructMap replacements; + for (const auto& [expr, name] : inferable_expressions) { + bool is_tir_var = expr->IsInstance(); + + auto expr_depends_on = tir::UndefinedVars(expr); + bool internal_variable_is_required = + std::any_of(expr_depends_on.begin(), expr_depends_on.end(), + [&](const tir::Var& subvar) { return required_expressions.count(subvar); }); + + if (!is_tir_var && !internal_variable_is_required) { + // For human-readability, use the location used to infer the + // shape to name the variable. (e.g. `A_dim0` for a parameter + // inferred from parameter `A->shape[0]`.) + replacements[expr] = tir::Var(name, expr->dtype); + } + } + + if (replacements.empty()) { + return func; + } + + SymbolicSubexprReplacer mutator(replacements); + return Downcast(mutator(func)); +} + +namespace transform { +Pass RemoveSymbolicExpressionInSubroutine() { + auto pass_func = [=](IRModule mod, PassContext pc) -> IRModule { + IRModule updates; + + for (const auto& [gvar, base_func] : mod->functions) { + if (auto func = base_func.as()) { + auto mutated = RemoveSymbolicExpressionInSubroutine(func.value()); + if (!mutated.same_as(base_func)) { + updates->Add(gvar, mutated); + } + } + } + + if (updates->functions.size()) { + mod.CopyOnWrite()->Update(updates); + } + return mod; + }; + return tvm::transform::CreateModulePass(pass_func, 0, "RemoveSymbolicExpressionInSubroutine", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.RemoveSymbolicExpressionInSubroutine") + .set_body_typed(RemoveSymbolicExpressionInSubroutine); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/remove_unused_parameters.cc b/src/relax/transform/remove_unused_parameters.cc index d053d56f3205..f29281edf2ed 100644 --- a/src/relax/transform/remove_unused_parameters.cc +++ b/src/relax/transform/remove_unused_parameters.cc @@ -84,17 +84,25 @@ std::optional AnalyzeCallee(Function func) { // symbolic variables. We still want to remove the relax variable // to reduce computational steps in the parent, but we need to // provide the symbolic variables the other steps. 
-  auto defined_tir_params = [&]() -> PSet<tir::Var> {
+  auto required_tir_vars = [&]() -> PSet<tir::Var> {
+    auto arr = FreeSymbolicVars(func->body);
+    return {arr.begin(), arr.end()};
+  }();
+
+  auto inferable_tir_params = [&]() -> PSet<tir::Var> {
     auto param_sinfo =
         TupleStructInfo(params.Map([](const auto& var) { return GetStructInfo(var); }));
     auto arr = DefinableTIRVarsInStructInfo(param_sinfo);
     return {arr.begin(), arr.end()};
   }();
 
-  // Use an array to define the order of the symbolic variables
+  // Collect any additional TIR variables that should be provided.
+  // The `DefinableTIRVarsInStructInfo` function returns the TIR
+  // variables in order of their occurrence, so the output is
+  // deterministic.
   Array<tir::Var> free_tir_vars;
-  for (const auto& tir_var : FreeSymbolicVars(func->body)) {
-    if (!defined_tir_params.count(tir_var)) {
+  for (const auto& tir_var : DefinableTIRVarsInStructInfo(GetStructInfo(func))) {
+    if (required_tir_vars.count(tir_var) && !inferable_tir_params.count(tir_var)) {
       free_tir_vars.push_back(tir_var);
     }
   }
diff --git a/tests/python/relax/test_transform_remove_symbolic_expression_in_subroutine.py b/tests/python/relax/test_transform_remove_symbolic_expression_in_subroutine.py
new file mode 100644
index 000000000000..5bb0cf51d316
--- /dev/null
+++ b/tests/python/relax/test_transform_remove_symbolic_expression_in_subroutine.py
@@ -0,0 +1,386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.script import ir as I, relax as R, tir as T
+
+
+class Base:
+    def test_after_remove_symbolic_expression(self):
+        """Run RemoveSymbolicExpressionInSubroutine and compare"""
+        after = relax.transform.RemoveSymbolicExpressionInSubroutine()(self.before)
+        tvm.ir.assert_structural_equal(self.expected, after)
+
+    def test_after_remove_unused(self):
+        """Run RemoveSymbolicExpressionInSubroutine, then remove unused parameters
+
+        The `RemoveSymbolicExpressionInSubroutine` transform is
+        designed to allow an expression to be inferred, where
+        previously the variables used within the expression were
+        explicitly provided.  After
+        `RemoveSymbolicExpressionInSubroutine`, the arguments
+        providing the explicit definition are unused, and can be
+        removed using `RemoveUnusedParameters`.
+
+        """
+        after = tvm.ir.transform.Sequential(
+            [
+                relax.transform.RemoveSymbolicExpressionInSubroutine(),
+                relax.transform.RemoveUnusedParameters(),
+            ]
+        )(self.before)
+        tvm.ir.assert_structural_equal(self.expected_after_removing_unused, after)
+
+
+class TestSimple(Base):
+    """Replace a PrimExpr with a single tir.Var
+
+    Here, the `batch_size` and `seq_len` variables are only used as
+    part of the expression `batch_size * seq_len`.
While the + expression `batch_size * seq_len` must be propagated to the output + shape, neither `batch_size` nor `seq_len` is otherwise required. + """ + + @property + def before(self): + @I.ir_module + class Module: + @R.function + def main( + A: R.Tensor(["batch_size", "seq_len", "hidden_size"], "float16") + ) -> R.Tensor(["batch_size", "seq_len", "hidden_size"], "float16"): + batch_size = T.int64() + seq_len = T.int64() + hidden_size = T.int64() + A_flat = R.reshape(A, [-1, hidden_size]) + A_flat_norm = Module.rms_norm_impl(A_flat, R.shape([batch_size, seq_len])) + A_norm = R.reshape(A_flat_norm, [batch_size, seq_len, hidden_size]) + return A_norm + + @R.function(private=True) + def rms_norm_impl( + A: R.Tensor(["batch_size * seq_len", "hidden_size"], "float16"), + _: R.Shape(["batch_size", "seq_len"]), + ) -> R.Tensor(["batch_size * seq_len", "hidden_size"], "float16"): + A_squared = R.multiply(A, A) + A_mean_squared = R.mean(A_squared, axis=1, keepdims=True) + A_rms = R.sqrt(A_mean_squared) + A_norm = A / A_rms + return A_norm + + return Module + + @property + def expected(self): + @I.ir_module + class Module: + @R.function + def main( + A: R.Tensor(["batch_size", "seq_len", "hidden_size"], "float16") + ) -> R.Tensor(["batch_size", "seq_len", "hidden_size"], "float16"): + batch_size = T.int64() + seq_len = T.int64() + hidden_size = T.int64() + A_flat = R.reshape(A, [-1, hidden_size]) + A_flat_norm = Module.rms_norm_impl(A_flat, R.shape([batch_size, seq_len])) + A_norm = R.reshape(A_flat_norm, [batch_size, seq_len, hidden_size]) + return A_norm + + @R.function(private=True) + def rms_norm_impl( + A: R.Tensor(["A_dim0", "hidden_size"], "float16"), + _: R.Shape(["batch_size", "seq_len"]), + ) -> R.Tensor(["A_dim0", "hidden_size"], "float16"): + A_squared = R.multiply(A, A) + A_mean_squared = R.mean(A_squared, axis=1, keepdims=True) + A_rms = R.sqrt(A_mean_squared) + A_norm = A / A_rms + return A_norm + + return Module + + @property + def expected_after_removing_unused(self): + @I.ir_module + class Module: + @R.function + def main( + A: R.Tensor(["batch_size", "seq_len", "hidden_size"], "float16") + ) -> R.Tensor(["batch_size", "seq_len", "hidden_size"], "float16"): + batch_size = T.int64() + seq_len = T.int64() + hidden_size = T.int64() + A_flat = R.reshape(A, [-1, hidden_size]) + A_flat_norm = Module.rms_norm_impl(A_flat) + A_norm = R.reshape(A_flat_norm, [batch_size, seq_len, hidden_size]) + return A_norm + + @R.function(private=True) + def rms_norm_impl( + A: R.Tensor(["A_dim0", "hidden_size"], "float16"), + ) -> R.Tensor(["A_dim0", "hidden_size"], "float16"): + A_squared = R.multiply(A, A) + A_mean_squared = R.mean(A_squared, axis=1, keepdims=True) + A_rms = R.sqrt(A_mean_squared) + A_norm = A / A_rms + return A_norm + + return Module + + +class TestNoMutationOfExternallyExposedSubroutine(Base): + """No changes to public-facing functions + + Identical to `TestSimple`, except that the subroutine may be + called directly by a user. Therefore, its signature may not be + altered. 
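+    (A function is treated as externally exposed when it carries a
+    `global_symbol` attribute, which non-private `@R.function`
+    definitions receive by default.)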
+ """ + + @property + def before(self): + @I.ir_module + class Module: + @R.function + def main( + A: R.Tensor(["batch_size", "seq_len", "hidden_size"], "float16") + ) -> R.Tensor(["batch_size", "seq_len", "hidden_size"], "float16"): + batch_size = T.int64() + seq_len = T.int64() + hidden_size = T.int64() + A_flat = R.reshape(A, [-1, hidden_size]) + A_flat_norm = Module.rms_norm_impl(A_flat, R.shape([batch_size, seq_len])) + A_norm = R.reshape(A_flat_norm, [batch_size, seq_len, hidden_size]) + return A_norm + + @R.function + def rms_norm_impl( + A: R.Tensor(["batch_size * seq_len", "hidden_size"], "float16"), + _: R.Shape(["batch_size", "seq_len"]), + ) -> R.Tensor(["batch_size * seq_len", "hidden_size"], "float16"): + A_squared = R.multiply(A, A) + A_mean_squared = R.mean(A_squared, axis=1, keepdims=True) + A_rms = R.sqrt(A_mean_squared) + A_norm = A / A_rms + return A_norm + + return Module + + expected = before + expected_after_removing_unused = before + + +class TestRemoveMultipleVariables(Base): + """Replace multiple expressions with tir.Var""" + + @property + def before(self): + @I.ir_module + class Module: + @R.function + def main( + A: R.Tensor(["n1", "n2", "n3", "n4", "n5"], "float16") + ) -> R.Tensor(["n1", "n2", "n4", "n5"], "float16"): + n1 = T.int64() + n2 = T.int64() + n3 = T.int64() + n4 = T.int64() + n5 = T.int64() + + A = R.reshape(A, [n1 * n2, n3, n4 * n5]) + A = Module.first_element(A, R.shape([n1, n2, n4, n5])) + A = R.reshape(A, [n1, n2, n4, n5]) + return A + + @R.function(private=True) + def first_element( + A: R.Tensor(["n1 * n2", "n3", "n4 * n5"], "float16"), + _: R.Shape(["n1", "n2", "n4", "n5"]), + ) -> R.Tensor(["n1 * n2", "n4 * n5"], "float16"): + A = R.strided_slice(A, axes=[1], begin=[0], end=[1]) + A = R.squeeze(A, axis=1) + return A + + return Module + + @property + def expected(self): + @I.ir_module + class Module: + @R.function + def main( + A: R.Tensor(["n1", "n2", "n3", "n4", "n5"], "float16") + ) -> R.Tensor(["n1", "n2", "n4", "n5"], "float16"): + n1 = T.int64() + n2 = T.int64() + n3 = T.int64() + n4 = T.int64() + n5 = T.int64() + + A = R.reshape(A, [n1 * n2, n3, n4 * n5]) + A = Module.first_element(A, R.shape([n1, n2, n4, n5])) + A = R.reshape(A, [n1, n2, n4, n5]) + return A + + @R.function(private=True) + def first_element( + A: R.Tensor(["n12", "n3", "n45"], "float16"), + _: R.Shape(["n1", "n2", "n4", "n5"]), + ) -> R.Tensor(["n12", "n45"], "float16"): + A = R.strided_slice(A, axes=[1], begin=[0], end=[1]) + A = R.squeeze(A, axis=1) + return A + + return Module + + @property + def expected_after_removing_unused(self): + @I.ir_module + class Module: + @R.function + def main( + A: R.Tensor(["n1", "n2", "n3", "n4", "n5"], "float16") + ) -> R.Tensor(["n1", "n2", "n4", "n5"], "float16"): + n1 = T.int64() + n2 = T.int64() + n3 = T.int64() + n4 = T.int64() + n5 = T.int64() + + A = R.reshape(A, [n1 * n2, n3, n4 * n5]) + A = Module.first_element(A) + A = R.reshape(A, [n1, n2, n4, n5]) + return A + + @R.function(private=True) + def first_element( + A: R.Tensor(["n12", "n3", "n45"], "float16"), + ) -> R.Tensor(["n12", "n45"], "float16"): + A = R.strided_slice(A, axes=[1], begin=[0], end=[1]) + A = R.squeeze(A, axis=1) + return A + + return Module + + +class TestNoReplacementIfVariableUsedInExpression(Base): + """Do not replace PrimExpr if tir.Var is required + + Here, the `batch_size` and `seq_len` variables are used in the + subroutine, as part of the `R.reshape` expression. 
The + `R.prim_value` arguments must be retained in order to define + `batch_size` and `seq_len` in the subroutine. + + """ + + @property + def before(self): + @I.ir_module + class Module: + @R.function + def main( + A: R.Tensor(["batch_size", "seq_len", "hidden_size"], "float16") + ) -> R.Tensor(["batch_size", "seq_len", "hidden_size"], "float16"): + batch_size = T.int64() + seq_len = T.int64() + hidden_size = T.int64() + A_flat = R.reshape(A, [-1, hidden_size]) + A_norm = Module.rms_norm_impl( + A_flat, R.prim_value(batch_size), R.prim_value(seq_len) + ) + + return A_norm + + @R.function(private=True) + def rms_norm_impl( + A: R.Tensor(["batch_size * seq_len", "hidden_size"], "float16"), + _1: R.Prim(value="batch_size"), + _2: R.Prim(value="seq_len"), + ) -> R.Tensor(["batch_size", "seq_len", "hidden_size"], "float16"): + batch_size = T.int64() + seq_len = T.int64() + hidden_size = T.int64() + + A_squared = R.multiply(A, A) + A_mean_squared = R.mean(A_squared, axis=1, keepdims=True) + A_rms = R.sqrt(A_mean_squared) + A_flat_norm = A / A_rms + A_norm = R.reshape(A_flat_norm, [batch_size, seq_len, hidden_size]) + return A_norm + + return Module + + expected = before + expected_after_removing_unused = before + + +class TestNoReplacementIfVariableUsedInMatchCast(Base): + """Do not replace PrimExpr if tir.Var is required + + Here, the `batch_size` and `seq_len` variables are used in the + subroutine, as part of the `R.match_cast` binding. The + `R.prim_value` arguments must be retained in order to define + `batch_size` and `seq_len` in the subroutine. + + """ + + @property + def before(self): + @I.ir_module + class Module: + @R.function + def main( + A: R.Tensor(["batch_size", "seq_len", "hidden_size"], "float16") + ) -> R.Tensor(["batch_size", "seq_len", "hidden_size"], "float16"): + batch_size = T.int64() + seq_len = T.int64() + hidden_size = T.int64() + A_flat = R.reshape(A, [-1, hidden_size]) + A_flat_norm = Module.rms_norm_impl( + A_flat, R.prim_value(batch_size), R.prim_value(seq_len) + ) + A_norm = R.reshape(A_flat_norm, [batch_size, seq_len, hidden_size]) + + return A_norm + + @R.function(private=True) + def rms_norm_impl( + A: R.Tensor(["batch_size * seq_len", "hidden_size"], "float16"), + _1: R.Prim(value="batch_size"), + _2: R.Prim(value="seq_len"), + ) -> R.Tensor(["batch_size * seq_len", "hidden_size"], "float16"): + batch_size = T.int64() + seq_len = T.int64() + hidden_size = T.int64() + + A_norm_ndim = R.call_pure_packed( + "some_packed_func_implementation", A, sinfo_args=[R.Tensor(ndim=3)] + ) + A_norm = R.match_cast(A_norm_ndim, R.Tensor([batch_size, seq_len, hidden_size])) + A_flat_norm = R.reshape(A_norm, [batch_size * seq_len, hidden_size]) + + return A_flat_norm + + return Module + + expected = before + expected_after_removing_unused = before + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform_remove_unused_parameters.py b/tests/python/relax/test_transform_remove_unused_parameters.py index ea905eb88283..f25e154b4db9 100644 --- a/tests/python/relax/test_transform_remove_unused_parameters.py +++ b/tests/python/relax/test_transform_remove_unused_parameters.py @@ -60,6 +60,9 @@ class TestReplaceSymbolicVariables(BaseCompare): its shape defines the symbolic variables `m` and `n`. When removing the `R.Tensor` argument, we may need to provide additional parameters to define the symbolic variables. + + The order of symbolic variables is determined by the order of + their first occurrence in the subroutine's signature. 
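+
+    For example, the subroutine below receives `param_m: R.Prim(value="m")`
+    before `param_n: R.Prim(value="n")`, matching the order in which `m`
+    and `n` first appear in `R.Tensor(["m", "n"])`.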
""" @I.ir_module @@ -80,12 +83,12 @@ class Expected: def main(A: R.Tensor(["m", "n"], "float32")) -> R.Tensor(["m", "n"], "float32"): m = T.int64() n = T.int64() - out: R.Tensor([m, n], "float32") = Expected.func(R.prim_value(n), R.prim_value(m)) + out: R.Tensor([m, n], "float32") = Expected.func(R.prim_value(m), R.prim_value(n)) return out @R.function(private=True) def func( - param_n: R.Prim(value="n"), param_m: R.Prim(value="m") + param_m: R.Prim(value="m"), param_n: R.Prim(value="n") ) -> R.Tensor(["m", "n"], "float32"): m = T.int64() n = T.int64()