Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
src/relax/distributed/*.cc
src/relax/distributed/transform/*.cc
src/relax/op/distributed/*.cc
src/relax/testing/*.cc
)

tvm_file_glob(GLOB CODEGEN_SRCS
Expand Down
43 changes: 43 additions & 0 deletions include/tvm/relax/block_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ class BlockBuilder : public ObjectRef {
* \brief Create a BlockBuilder.
*
* \param ctx_mod Optional before-transformation context module for rewriting.
*
* \return The created BlockBuilder.
*
* \note When rewriting an existing IRModule, it is important to pass it in as
Expand All @@ -231,6 +232,48 @@ class BlockBuilder : public ObjectRef {
*/
TVM_DLL static BlockBuilder Create(Optional<IRModule> ctx_mod);

/*! \brief A marker struct to disable FNormalize
*
* This struct is used as a marker to disable the use of FNormalize
* by this block builder. This should only be used for TVMScript
* parsing, which may require producing un-normalized Relax IR for
* testing purposes, and to ensure that round-trips are unchanged.
*
* The name is deliberately verbose to draw attention during a code
* review. The explicit default constructor prevents aggregate
* initialization, ensuring that the full name of the marker struct
* appears at the callsite.
*
* This constructor is marked as no-lint to allow a zero-parameter
* constructor to be marked as explicit. The constructor must be
* explicit in order to disable aggregate initialization in C++17.
* While C++20 disables aggregate initialization when a
* user-declared constructor is present, C++17 only disables
* aggregate initialization when a user-defined constructor is
* present. Therefore, we need to mark the zero-parameter
* constructor as explicit in order to prevent aggregate
* initialization, and to ensure that the name appears at all
* callsites.
*/
struct DisableOperatorSpecificNormalizationForTVMScript {
explicit DisableOperatorSpecificNormalizationForTVMScript() = default; // NOLINT(*)
};
/*!
* \brief Create a BlockBuilder.
*
* \param ctx_mod Optional before-transformation context module for rewriting.
*
* \param tag An instance of DisableOperatorSpecificNormalizationForTVMScript
*
* \return The created BlockBuilder.
*
* \note When rewriting an existing IRModule, it is important to pass it in as
* ctx_mod so you can lookup the context functions for cross function
* call analysis.
*/
TVM_DLL static BlockBuilder Create(Optional<IRModule> ctx_mod,
DisableOperatorSpecificNormalizationForTVMScript tag);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BlockBuilder, ObjectRef, BlockBuilderNode);
};

Expand Down
27 changes: 24 additions & 3 deletions include/tvm/relax/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,30 @@ using FInferStructInfo =
using FCallPacked = String;

/*!
* \brief The function type of a legalization function, which takes a
* BlockBuilder and the Call to be legalized, and outputs the legalization
* result Expr.
* \brief The function type of a normalization function.
*
* A normalization function is used when a `relax::Call` may be
* expressed in multiple syntactically valid and semantically
* equivalent forms, to normalize to a single representation.
*
* \param bb The BlockBuilder context.
*
* \param call The call to be normalized. It is provided by-value, to
* avoid copies for the common case where the call is already normalized.
*/
using FNormalize = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, Call call)>;

/*! \brief The function type of a legalization function.
*
* A legalization function is used to replace a `relax::Call` with
* more concrete implementations. For example, the operation
* `relax.op.add` may be replaced with a call to a TIR function
* implementing addition of two tensors.
*
* The purpose of `FLegalize` is to remove calls to the operator while
* lowering. Therefore, unlike `FNormalize`, the resulting expression
* may *not* contain the original operator.
*
* \param bb The BlockBuilder context.
* \param call The call to be legalized.
*/
Expand Down
19 changes: 17 additions & 2 deletions python/tvm/relax/ir/instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,29 @@
class WellFormedInstrument:
"""An instrument that checks the input/output IRModule of the Pass
is well formed. It will skip specific passes, like Normalize.

Parameters
----------
check_struct_info: bool

If True, validate the struct info in the module. If False,
skip these checks.

validate_before_transform: bool

If True (default), perform a well-formed check before running
a transform. If False, only perform the well-formed check
after running a transform.
"""

def __init__(self, check_struct_info=True):
def __init__(self, check_struct_info: bool = True, validate_before_transform: bool = True):
self.skip_pass_name = ["Normalize", "ResolveGlobals"]
self.check_struct_info = check_struct_info
self.validate_before_transform = validate_before_transform

def run_before_pass(self, mod, pass_info):
self._check(mod, pass_info.name, "Before")
if self.validate_before_transform:
self._check(mod, pass_info.name, "Before")

def run_after_pass(self, mod, pass_info):
self._check(mod, pass_info.name, "After")
Expand Down
6 changes: 6 additions & 0 deletions python/tvm/relax/testing/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=unused-argument, invalid-name, no-else-return, abstract-method, arguments-differ
"""Relax transformation passes for testing"""

import tvm
from tvm import ir, relax
from tvm.ir import transform
from tvm.ir.module import IRModule
Expand Down Expand Up @@ -122,3 +123,8 @@ def transform(self):
return new_mod

return Lowerer().transform()


def ApplyEmptyCppMutator() -> tvm.ir.transform.Pass:
packed_func = tvm.get_global_func("relax.testing.transform.ApplyEmptyCppMutator")
return packed_func()
43 changes: 33 additions & 10 deletions src/relax/analysis/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/op_attr_types.h>
#include <tvm/relax/struct_info_functor.h>
#include <tvm/relax/utils.h>
#include <tvm/tir/expr_functor.h>
Expand Down Expand Up @@ -282,17 +283,17 @@ class WellFormedChecker : public relax::ExprVisitor,
}
}

void VisitExpr_(const CallNode* op) final {
if (IsLeafOrTuple(op->op)) {
void VisitExpr_(const CallNode* call) final {
if (IsLeafOrTuple(call->op)) {
const FunctionNode* prev_visited_func = cur_visited_func_;
cur_visited_func_ = nullptr; // close the symbolic var dup check
this->VisitExpr(op->op);
this->VisitExpr(call->op);
cur_visited_func_ = prev_visited_func;
} else {
Malformed(Diagnostic::Error(op) << "The called expression must be a leaf expression");
Malformed(Diagnostic::Error(call) << "The called expression must be a leaf expression");
}
for (size_t i = 0; i < op->args.size(); i++) {
Expr arg = op->args[i];
for (size_t i = 0; i < call->args.size(); i++) {
Expr arg = call->args[i];
if (IsLeafOrTuple(arg)) {
this->VisitExpr(arg);
} else {
Expand All @@ -301,13 +302,33 @@ class WellFormedChecker : public relax::ExprVisitor,
}
}

for (const StructInfo& sinfo_arg : op->sinfo_args) {
for (const StructInfo& sinfo_arg : call->sinfo_args) {
this->VisitStructInfo(sinfo_arg);
}

CheckStructInfo(op);
if (is_dataflow_ && check_struct_info_ && IsImpureCall(GetRef<Call>(op))) {
Malformed(Diagnostic::Error(op) << "There cannot be an impure call inside a dataflow block.");
CheckStructInfo(call);
if (is_dataflow_ && check_struct_info_ && IsImpureCall(GetRef<Call>(call))) {
Malformed(Diagnostic::Error(call)
<< "There cannot be an impure call inside a dataflow block.");
}

// If the operation has defined a custom normalization function
// using the FNormalize attribute, the call node must be normalized in order to be well-formed.
// If we apply the FNormalize and it produces any change, modified the expression, re-visit in
// case it produced a nested expression.

if (auto func_normalize = op_map_normalize_.get(call->op, nullptr); func_normalize != nullptr) {
auto dummy_builder = tvm::relax::BlockBuilder::Create(mod_);
auto before_normalize = GetRef<Call>(call);
auto after_normalize = func_normalize(dummy_builder, before_normalize);
if (!before_normalize.same_as(after_normalize)) {
Malformed(
Diagnostic::Error(call)
<< "If an operator defines an operator-specific normalization function (FNormalize), "
<< "calls to that operator must be normalized with it. "
<< "However, normalization of " << before_normalize << " resulted in "
<< after_normalize);
}
}
}

Expand Down Expand Up @@ -538,6 +559,8 @@ class WellFormedChecker : public relax::ExprVisitor,
std::unordered_map<Var, const FunctionNode*, ObjectPtrHash, ObjectPtrEqual> param_var_func_map_;
std::unordered_map<tir::Var, const FunctionNode*, ObjectPtrHash, ObjectPtrEqual>
symbolic_var_func_map_;

tvm::OpAttrMap<FNormalize> op_map_normalize_ = Op::GetAttrMap<FNormalize>("FNormalize");
};

bool WellFormed(IRModule m, bool check_struct_info) {
Expand Down
39 changes: 30 additions & 9 deletions src/relax/ir/block_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,9 @@ class BlockBuilderImpl : public BlockBuilderNode {
class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&)> {
public:
explicit Normalizer(IRModule context_mod) : BlockBuilderImpl(context_mod) {}
explicit Normalizer(IRModule context_mod,
BlockBuilder::DisableOperatorSpecificNormalizationForTVMScript)
: BlockBuilderImpl(context_mod), apply_f_normalize_(false) {}

Expr Normalize(const Expr& expr) final {
Expr normalized = this->VisitExpr(expr);
Expand Down Expand Up @@ -578,18 +581,11 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&

Expr VisitExpr_(const CallNode* op) final {
Expr new_op = this->NormalizeArgument(op->op);
bool unchanged = new_op.same_as(op->op);

Array<Expr> new_args;

for (Expr arg : op->args) {
Expr new_arg = this->NormalizeArgument(arg);
new_args.push_back(new_arg);
unchanged &= new_arg.same_as(arg);
}
Array<Expr> new_args = op->args.Map([this](const Expr& arg) { return NormalizeArgument(arg); });

Call call;
if (unchanged) {
if (new_op.same_as(op->op) && new_args.same_as(op->args)) {
call = GetRef<Call>(op);
} else {
call = Call(new_op, new_args, op->attrs, op->sinfo_args);
Expand All @@ -600,6 +596,19 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&
UpdateStructInfo(call, inferred_sinfo);
}

// If the operation has defined a custom normalization
// function using the FNormalize attribute, apply it. If the
// normalization modified the expression, re-visit in case it
// produced a nested expression.
if (apply_f_normalize_) {
if (auto func_normalize = op_map_normalize_.get(op->op, nullptr); func_normalize != nullptr) {
Expr normalized = func_normalize(GetRef<BlockBuilder>(this), call);
if (!normalized.same_as(call)) {
return VisitExpr(normalized);
}
}
}

return call;
}

Expand Down Expand Up @@ -917,13 +926,25 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&
Op::GetAttrMap<FInferStructInfo>("FInferStructInfo");
tvm::OpAttrMap<FInferStructInfo> op_map_dist_infer_struct_info_ =
Op::GetAttrMap<FInferStructInfo>("dist.FInferStructInfo");
/*! \brief Operator normalization function */
tvm::OpAttrMap<FNormalize> op_map_normalize_ = Op::GetAttrMap<FNormalize>("FNormalize");

/*! \brief Whether the FNormalize function should be applied */
bool apply_f_normalize_{true};
};

BlockBuilder BlockBuilder::Create(Optional<IRModule> mod) {
ObjectPtr<BlockBuilderNode> n = make_object<Normalizer>(mod.value_or(IRModule()));
return BlockBuilder(n);
}

BlockBuilder BlockBuilder::Create(Optional<IRModule> mod,
BlockBuilder::DisableOperatorSpecificNormalizationForTVMScript) {
ObjectPtr<BlockBuilderNode> n = make_object<Normalizer>(
mod.value_or(IRModule()), BlockBuilder::DisableOperatorSpecificNormalizationForTVMScript());
return BlockBuilder(n);
}

//---------------------------------------
// User facing function registration.
//---------------------------------------
Expand Down
9 changes: 9 additions & 0 deletions src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,14 @@ StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) {
return call->sinfo_args[0];
}

Expr NormalizeCallTIR(const BlockBuilder&, Call call) {
// Temporary implementation to ensure that at least one op has a
// registered value for FNormalize. This temporary implementation
// is fully implemented in follow-up PR
// https://github.com/apache/tvm/pull/16068.
return std::move(call);
}

RELAY_REGISTER_OP("relax.call_tir")
.set_num_inputs(3)
.add_argument("func", "Expr", "The destination-passing-style function.")
Expand All @@ -261,6 +269,7 @@ RELAY_REGISTER_OP("relax.call_tir")
"ShapeExpr representing a tuple of ints to unpack during runtime. Omitted from "
"args if unused")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoCallTIR)
.set_attr<FNormalize>("FNormalize", NormalizeCallTIR)
.set_attr<Bool>("FPurity", Bool(true));

Expr MakeCallTIR(Expr func, Tuple args, Array<TensorStructInfo> out_sinfo_list,
Expand Down
43 changes: 43 additions & 0 deletions src/relax/testing/transform.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>

namespace tvm {
namespace relax {
namespace testing {

class EmptyCppMutator : public relax::ExprMutator {};

tvm::transform::Pass ApplyEmptyCppMutator() {
auto pass_func = [](Function func, IRModule, tvm::transform::PassContext) -> Function {
EmptyCppMutator mutator;
return Downcast<Function>(mutator.VisitExpr(std::move(func)));
};
return tvm::relax::transform::CreateFunctionPass(pass_func, 0,
"relax.testing.ApplyEmptyCppMutator", {});
}

TVM_REGISTER_GLOBAL("relax.testing.transform.ApplyEmptyCppMutator")
.set_body_typed(ApplyEmptyCppMutator);

} // namespace testing
} // namespace relax
} // namespace tvm
3 changes: 2 additions & 1 deletion src/script/ir_builder/relax/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ FunctionFrame Function(const Bool& is_pure, const Bool& is_private) {
if (const Optional<ir::IRModuleFrame> mod_frame = ir_builder->GetLastFrame<ir::IRModuleFrame>()) {
mod = tvm::IRModule(mod_frame.value()->functions);
}
n->block_builder = tvm::relax::BlockBuilder::Create(/*mod=*/mod);
n->block_builder = tvm::relax::BlockBuilder::Create(
/*mod=*/mod, tvm::relax::BlockBuilder::DisableOperatorSpecificNormalizationForTVMScript());
n->is_pure = is_pure;
n->is_private = is_private;
return FunctionFrame(n);
Expand Down
Loading