diff --git a/CMakeLists.txt b/CMakeLists.txt index 3a47d99a185a..8b1f89a075f0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -311,6 +311,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/relax/distributed/*.cc src/relax/distributed/transform/*.cc src/relax/op/distributed/*.cc + src/relax/testing/*.cc ) tvm_file_glob(GLOB CODEGEN_SRCS diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index 0a5192d6580f..9b339babb435 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -223,6 +223,7 @@ class BlockBuilder : public ObjectRef { * \brief Create a BlockBuilder. * * \param ctx_mod Optional before-transformation context module for rewriting. + * * \return The created BlockBuilder. * * \note When rewriting an existing IRModule, it is important to pass it in as @@ -231,6 +232,48 @@ class BlockBuilder : public ObjectRef { */ TVM_DLL static BlockBuilder Create(Optional ctx_mod); + /*! \brief A marker struct to disable FNormalize + * + * This struct is used as a marker to disable the use of FNormalize + * by this block builder. This should only be used for TVMScript + * parsing, which may require producing un-normalized Relax IR for + * testing purposes, and to ensure that round-trips are unchanged. + * + * The name is deliberately verbose to draw attention during a code + * review. The explicit default constructor prevents aggregate + * initialization, ensuring that the full name of the marker struct + * appears at the callsite. + * + * This constructor is marked as no-lint to allow a zero-parameter + * constructor to be marked as explicit. The constructor must be + * explicit in order to disable aggregate initialization in C++17. + * While C++20 disables aggregate initialization when a + * user-declared constructor is present, C++17 only disables + * aggregate initialization when a user-defined constructor is + * present. Therefore, we need to mark the zero-parameter + * constructor as explicit in order to prevent aggregate + * initialization, and to ensure that the name appears at all + * callsites. + */ + struct DisableOperatorSpecificNormalizationForTVMScript { + explicit DisableOperatorSpecificNormalizationForTVMScript() = default; // NOLINT(*) + }; + /*! + * \brief Create a BlockBuilder. + * + * \param ctx_mod Optional before-transformation context module for rewriting. + * + * \param tag An instance of DisableOperatorSpecificNormalizationForTVMScript + * + * \return The created BlockBuilder. + * + * \note When rewriting an existing IRModule, it is important to pass it in as + * ctx_mod so you can lookup the context functions for cross function + * call analysis. + */ + TVM_DLL static BlockBuilder Create(Optional ctx_mod, + DisableOperatorSpecificNormalizationForTVMScript tag); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BlockBuilder, ObjectRef, BlockBuilderNode); }; diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h index 64e5bd89a58c..b44c4582d82d 100644 --- a/include/tvm/relax/op_attr_types.h +++ b/include/tvm/relax/op_attr_types.h @@ -50,9 +50,30 @@ using FInferStructInfo = using FCallPacked = String; /*! - * \brief The function type of a legalization function, which takes a - * BlockBuilder and the Call to be legalized, and outputs the legalization - * result Expr. + * \brief The function type of a normalization function. + * + * A normalization function is used when a `relax::Call` may be + * expressed in multiple syntactically valid and semantically + * equivalent forms, to normalize to a single representation. + * + * \param bb The BlockBuilder context. + * + * \param call The call to be normalized. It is provided by-value, to + * avoid copies for the common case where the call is already normalized. + */ +using FNormalize = runtime::TypedPackedFunc; + +/*! \brief The function type of a legalization function. + * + * A legalization function is used to replace a `relax::Call` with + * more concrete implementations. For example, the operation + * `relax.op.add` may be replaced with a call to a TIR function + * implementing addition of two tensors. + * + * The purpose of `FLegalize` is to remove calls to the operator while + * lowering. Therefore, unlike `FNormalize`, the resulting expression + * may *not* contain the original operator. + * * \param bb The BlockBuilder context. * \param call The call to be legalized. */ diff --git a/python/tvm/relax/ir/instrument.py b/python/tvm/relax/ir/instrument.py index a297e3f15a56..1ecd87fe1b97 100644 --- a/python/tvm/relax/ir/instrument.py +++ b/python/tvm/relax/ir/instrument.py @@ -23,14 +23,29 @@ class WellFormedInstrument: """An instrument that checks the input/output IRModule of the Pass is well formed. It will skip specific passes, like Normalize. + + Parameters + ---------- + check_struct_info: bool + + If True, validate the struct info in the module. If False, + skip these checks. + + validate_before_transform: bool + + If True (default), perform a well-formed check before running + a transform. If False, only perform the well-formed check + after running a transform. """ - def __init__(self, check_struct_info=True): + def __init__(self, check_struct_info: bool = True, validate_before_transform: bool = True): self.skip_pass_name = ["Normalize", "ResolveGlobals"] self.check_struct_info = check_struct_info + self.validate_before_transform = validate_before_transform def run_before_pass(self, mod, pass_info): - self._check(mod, pass_info.name, "Before") + if self.validate_before_transform: + self._check(mod, pass_info.name, "Before") def run_after_pass(self, mod, pass_info): self._check(mod, pass_info.name, "After") diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py index 6b58b087feb9..ccae38a138a3 100644 --- a/python/tvm/relax/testing/transform.py +++ b/python/tvm/relax/testing/transform.py @@ -17,6 +17,7 @@ # pylint: disable=unused-argument, invalid-name, no-else-return, abstract-method, arguments-differ """Relax transformation passes for testing""" +import tvm from tvm import ir, relax from tvm.ir import transform from tvm.ir.module import IRModule @@ -122,3 +123,8 @@ def transform(self): return new_mod return Lowerer().transform() + + +def ApplyEmptyCppMutator() -> tvm.ir.transform.Pass: + packed_func = tvm.get_global_func("relax.testing.transform.ApplyEmptyCppMutator") + return packed_func() diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index 79135b943ae6..5cb577e82bda 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -67,6 +67,7 @@ #include #include #include +#include #include #include #include @@ -282,17 +283,17 @@ class WellFormedChecker : public relax::ExprVisitor, } } - void VisitExpr_(const CallNode* op) final { - if (IsLeafOrTuple(op->op)) { + void VisitExpr_(const CallNode* call) final { + if (IsLeafOrTuple(call->op)) { const FunctionNode* prev_visited_func = cur_visited_func_; cur_visited_func_ = nullptr; // close the symbolic var dup check - this->VisitExpr(op->op); + this->VisitExpr(call->op); cur_visited_func_ = prev_visited_func; } else { - Malformed(Diagnostic::Error(op) << "The called expression must be a leaf expression"); + Malformed(Diagnostic::Error(call) << "The called expression must be a leaf expression"); } - for (size_t i = 0; i < op->args.size(); i++) { - Expr arg = op->args[i]; + for (size_t i = 0; i < call->args.size(); i++) { + Expr arg = call->args[i]; if (IsLeafOrTuple(arg)) { this->VisitExpr(arg); } else { @@ -301,13 +302,33 @@ class WellFormedChecker : public relax::ExprVisitor, } } - for (const StructInfo& sinfo_arg : op->sinfo_args) { + for (const StructInfo& sinfo_arg : call->sinfo_args) { this->VisitStructInfo(sinfo_arg); } - CheckStructInfo(op); - if (is_dataflow_ && check_struct_info_ && IsImpureCall(GetRef(op))) { - Malformed(Diagnostic::Error(op) << "There cannot be an impure call inside a dataflow block."); + CheckStructInfo(call); + if (is_dataflow_ && check_struct_info_ && IsImpureCall(GetRef(call))) { + Malformed(Diagnostic::Error(call) + << "There cannot be an impure call inside a dataflow block."); + } + + // If the operation has defined a custom normalization function + // using the FNormalize attribute, the call node must be normalized in order to be well-formed. + // If we apply the FNormalize and it produces any change, modified the expression, re-visit in + // case it produced a nested expression. + + if (auto func_normalize = op_map_normalize_.get(call->op, nullptr); func_normalize != nullptr) { + auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_); + auto before_normalize = GetRef(call); + auto after_normalize = func_normalize(dummy_builder, before_normalize); + if (!before_normalize.same_as(after_normalize)) { + Malformed( + Diagnostic::Error(call) + << "If an operator defines an operator-specific normalization function (FNormalize), " + << "calls to that operator must be normalized with it. " + << "However, normalization of " << before_normalize << " resulted in " + << after_normalize); + } } } @@ -538,6 +559,8 @@ class WellFormedChecker : public relax::ExprVisitor, std::unordered_map param_var_func_map_; std::unordered_map symbolic_var_func_map_; + + tvm::OpAttrMap op_map_normalize_ = Op::GetAttrMap("FNormalize"); }; bool WellFormed(IRModule m, bool check_struct_info) { diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 5037161fcb90..f58ea20223f1 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -463,6 +463,9 @@ class BlockBuilderImpl : public BlockBuilderNode { class Normalizer : public BlockBuilderImpl, private ExprFunctor { public: explicit Normalizer(IRModule context_mod) : BlockBuilderImpl(context_mod) {} + explicit Normalizer(IRModule context_mod, + BlockBuilder::DisableOperatorSpecificNormalizationForTVMScript) + : BlockBuilderImpl(context_mod), apply_f_normalize_(false) {} Expr Normalize(const Expr& expr) final { Expr normalized = this->VisitExpr(expr); @@ -578,18 +581,11 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorNormalizeArgument(op->op); - bool unchanged = new_op.same_as(op->op); - Array new_args; - - for (Expr arg : op->args) { - Expr new_arg = this->NormalizeArgument(arg); - new_args.push_back(new_arg); - unchanged &= new_arg.same_as(arg); - } + Array new_args = op->args.Map([this](const Expr& arg) { return NormalizeArgument(arg); }); Call call; - if (unchanged) { + if (new_op.same_as(op->op) && new_args.same_as(op->args)) { call = GetRef(op); } else { call = Call(new_op, new_args, op->attrs, op->sinfo_args); @@ -600,6 +596,19 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorop, nullptr); func_normalize != nullptr) { + Expr normalized = func_normalize(GetRef(this), call); + if (!normalized.same_as(call)) { + return VisitExpr(normalized); + } + } + } + return call; } @@ -917,6 +926,11 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor("FInferStructInfo"); tvm::OpAttrMap op_map_dist_infer_struct_info_ = Op::GetAttrMap("dist.FInferStructInfo"); + /*! \brief Operator normalization function */ + tvm::OpAttrMap op_map_normalize_ = Op::GetAttrMap("FNormalize"); + + /*! \brief Whether the FNormalize function should be applied */ + bool apply_f_normalize_{true}; }; BlockBuilder BlockBuilder::Create(Optional mod) { @@ -924,6 +938,13 @@ BlockBuilder BlockBuilder::Create(Optional mod) { return BlockBuilder(n); } +BlockBuilder BlockBuilder::Create(Optional mod, + BlockBuilder::DisableOperatorSpecificNormalizationForTVMScript) { + ObjectPtr n = make_object( + mod.value_or(IRModule()), BlockBuilder::DisableOperatorSpecificNormalizationForTVMScript()); + return BlockBuilder(n); +} + //--------------------------------------- // User facing function registration. //--------------------------------------- diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 01d0d04be0cc..fe74286a51f1 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -253,6 +253,14 @@ StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) { return call->sinfo_args[0]; } +Expr NormalizeCallTIR(const BlockBuilder&, Call call) { + // Temporary implementation to ensure that at least one op has a + // registered value for FNormalize. This temporary implementation + // is fully implemented in follow-up PR + // https://github.com/apache/tvm/pull/16068. + return std::move(call); +} + RELAY_REGISTER_OP("relax.call_tir") .set_num_inputs(3) .add_argument("func", "Expr", "The destination-passing-style function.") @@ -261,6 +269,7 @@ RELAY_REGISTER_OP("relax.call_tir") "ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from " "args if unused") .set_attr("FInferStructInfo", InferStructInfoCallTIR) + .set_attr("FNormalize", NormalizeCallTIR) .set_attr("FPurity", Bool(true)); Expr MakeCallTIR(Expr func, Tuple args, Array out_sinfo_list, diff --git a/src/relax/testing/transform.cc b/src/relax/testing/transform.cc new file mode 100644 index 000000000000..eed2329e3d3a --- /dev/null +++ b/src/relax/testing/transform.cc @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +namespace tvm { +namespace relax { +namespace testing { + +class EmptyCppMutator : public relax::ExprMutator {}; + +tvm::transform::Pass ApplyEmptyCppMutator() { + auto pass_func = [](Function func, IRModule, tvm::transform::PassContext) -> Function { + EmptyCppMutator mutator; + return Downcast(mutator.VisitExpr(std::move(func))); + }; + return tvm::relax::transform::CreateFunctionPass(pass_func, 0, + "relax.testing.ApplyEmptyCppMutator", {}); +} + +TVM_REGISTER_GLOBAL("relax.testing.transform.ApplyEmptyCppMutator") + .set_body_typed(ApplyEmptyCppMutator); + +} // namespace testing +} // namespace relax +} // namespace tvm diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index 9af52fa80bd4..285a3a348e3b 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -59,7 +59,8 @@ FunctionFrame Function(const Bool& is_pure, const Bool& is_private) { if (const Optional mod_frame = ir_builder->GetLastFrame()) { mod = tvm::IRModule(mod_frame.value()->functions); } - n->block_builder = tvm::relax::BlockBuilder::Create(/*mod=*/mod); + n->block_builder = tvm::relax::BlockBuilder::Create( + /*mod=*/mod, tvm::relax::BlockBuilder::DisableOperatorSpecificNormalizationForTVMScript()); n->is_pure = is_pure; n->is_private = is_private; return FunctionFrame(n); diff --git a/tests/python/relax/conftest.py b/tests/python/relax/conftest.py index f1b1187066e6..1e12a95e524b 100644 --- a/tests/python/relax/conftest.py +++ b/tests/python/relax/conftest.py @@ -20,4 +20,58 @@ from tvm.relax.ir.instrument import WellFormedInstrument -tvm.transform.PassContext.current().override_instruments([WellFormedInstrument()]) +@pytest.fixture +def unit_test_marks(request): + """Get all marks applied to a test + + From https://stackoverflow.com/a/61379477. + """ + marks = [m.name for m in request.node.iter_markers()] + if request.node.parent: + marks += [m.name for m in request.node.parent.iter_markers()] + yield marks + + +def pytest_configure(config): + config.addinivalue_line( + "markers", + ( + "skip_well_formed_check_before_transform: " + "Only check for well-formed IRModule after a transform" + ), + ) + + +# By default, apply the well-formed check before and after all +# transforms. Checking well-formed-ness after the transform ensures +# that all transforms produce well-formed output. Checking +# well-formed-ness before the transform ensures that test cases +# (usually hand-written) are providing well-formed inputs. +# +# This is provided as a test fixture so that it can be overridden for +# specific tests. If a test must provide ill-formed input to a +# transform, it can be marked with +# `@pytest.mark.skip_well_formed_check_before_transform` +@pytest.fixture(autouse=True) +def apply_instrument_well_formed(unit_test_marks): + + validate_before_transform = "skip_well_formed_check_before_transform" not in unit_test_marks + + instrument = WellFormedInstrument(validate_before_transform=validate_before_transform) + current = tvm.transform.PassContext.current() + + override = tvm.transform.PassContext( + # Append the new instrument + instruments=[*current.instruments, instrument], + # Forward all other parameters + opt_level=current.opt_level, + required_pass=current.required_pass, + disabled_pass=current.disabled_pass, + config=current.config, + trace_stack=current.trace_stack, + make_traceable=current.make_traceable, + num_evals=current.num_evals, + tuning_api_database=current.get_tuning_api_database(), + ) + with override: + yield diff --git a/tests/python/relax/test_transform_operator_specific_normalization.py b/tests/python/relax/test_transform_operator_specific_normalization.py new file mode 100644 index 000000000000..07d541ab1ed5 --- /dev/null +++ b/tests/python/relax/test_transform_operator_specific_normalization.py @@ -0,0 +1,171 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Test FNormalize usage""" + +import tvm +import tvm.testing +import tvm.relax.testing.transform + +from tvm import relax +from tvm.script.parser import ir as I, relax as R + +import pytest + +define_normalization = tvm.testing.parameter(True) + + +@pytest.fixture +def custom_op(define_normalization): + """A custom operator for testing purposes + + The custom operator ignores its second argument. If there isn't a + second argument which can be ignored, FNormalize appends an + additional argument so that it can be properly ignored. + """ + + op_name = "custom_op.ignore_second_argument" + + def infer_struct_info(call: relax.Call, context: relax.BlockBuilder): + return call.args[0].struct_info + + def normalize(context: relax.BlockBuilder, call: relax.Call): + if len(call.args) == 1: + return relax.Call(call.op, [call.args[0], relax.Tuple([])]) + else: + return call + + def legalize(context: relax.BlockBuilder, call: relax.Call): + return call.args[0] + + op_attrs = { + "FInferStructInfo": infer_struct_info, + "FLegalize": legalize, + "FPurity": True, + } + if define_normalization: + op_attrs["FNormalize"] = normalize + + for key, value in op_attrs.items(): + tvm.ir.register_op_attr(op_name, key, value) + + op = tvm.ir.Op.get(op_name) + yield op + + for key in op_attrs: + op.reset_attr(key) + + +def test_normalization_suppressed_for_tvmscript(custom_op): + """FNormalize isn't applied when parsing TVMScript + + TVMScript should be able to produce un-normalized Relax IR for + specifying test cases, and to ensure that no changes occur when + performing a round-trip through TVMScript. + """ + + @R.function + def func(A: R.Tensor): + return relax.Call(custom_op, [A]) + + call_expr = func.body.blocks[0].bindings[0].value + assert isinstance( + call_expr, relax.Call + ), "Test implementation error, didn't extract the correct expression" + assert ( + len(call_expr.args) == 1 + ), "Expected TVMScript to suppress use of FNormalize, produce arguments as written" + + +@pytest.mark.skip_well_formed_check_before_transform +def test_normalization_applied_during_cpp_mutator(custom_op): + """FNormalize is applied by relax::ExprMutator subclasses""" + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor): + return relax.Call(custom_op, [A]) + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor): + return relax.Call(custom_op, [A, R.tuple()]) + + After = tvm.relax.testing.transform.ApplyEmptyCppMutator()(Before) + + assert not tvm.ir.structural_equal(Before, After) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_normalization_applied_during_python_mutator(custom_op): + """FNormalize is applied by relax.ExprMutator subclasses""" + + @R.function(private=True) + def before(A: R.Tensor): + return relax.Call(custom_op, [A]) + + @R.function(private=True) + def expected(A: R.Tensor): + return relax.Call(custom_op, [A, R.tuple()]) + + @relax.expr_functor.mutator + class EmptyPyExprMutator(relax.PyExprMutator): + """Default ExprMutator""" + + after = EmptyPyExprMutator().visit_expr(before) + + assert not tvm.ir.structural_equal(before, after) + tvm.ir.assert_structural_equal(expected, after) + + +def test_normalized_call_node_is_well_formed(custom_op): + """If FNormalize wouldn't apply a change, the IR is well-formed""" + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor): + return relax.Call(custom_op, [A, A]) + + assert relax.analysis.well_formed(Module) + + +@pytest.mark.skip_well_formed_check_before_transform +@pytest.mark.parametrize("define_normalization", [True, False]) +def test_un_normalized_call_node_is_ill_formed(custom_op, define_normalization): + """If FNormalize would apply a change, the IR is ill-formed + + This only applies if FNormalize exists. An operator without + FNormalize has no corresponding check applied. + """ + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor): + return relax.Call(custom_op, [A]) + + if define_normalization: + assert not relax.analysis.well_formed(Module) + else: + assert relax.analysis.well_formed(Module) + + +if __name__ == "__main__": + tvm.testing.main()