Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,26 @@ TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr, const Array<Var>& defs);

/*!
* \brief Analyze the side effect
* \brief Analyze the side effect of an expression
* \param expr The expression to be checked.
*
* \return CallEffectKind, can be kPure, kReadState or kUpdateState
*/
TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);

/*!
* \brief Analyze the side effect of a function
*
* \param func The expression to be checked.
*
* \param assert_on_error If true, an error will be thrown for an
* impure function. If false (default), the purity of the PrimFunc
* will be returned.
*
* \return The purity of the function
*/
TVM_DLL bool IsPureFunction(const PrimFunc& func, bool assert_on_error = false);

/*!
* \brief Whether the given Stmt uses any var in the given variable set.
* \param stmt The Stmt to be checked.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(self, msg):
register_error("AttributeError", AttributeError)
register_error("KeyError", KeyError)
register_error("IndexError", IndexError)
register_error("AssertionError", AssertionError)


@register_error
Expand Down
44 changes: 28 additions & 16 deletions python/tvm/relax/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,19 +503,26 @@ def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Ob
f"The format string argument to assert must be a string, given {type(format_str)})"
)

# should be guaranteed by the type system
if not isinstance(condition, tvm.nd.NDArray):
raise ValueError(f"The condition must be an NDArray, but given a {type(condition)}.")

# may happen if the original program had unknown shape or dtype for the tensor's type
dtype = condition.dtype
if dtype != "bool":
raise ValueError(f"The condition must be a bool scalar, but given a {dtype} tensor")
shape = condition.shape
if len(shape) != 0:
raise ValueError(f"The condition must be a scalar, but it has a shape of {shape}")

val = condition.numpy()
if isinstance(condition, (bool, int)):
val = condition
elif isinstance(condition, tvm.nd.NDArray):
# may happen if the original program had unknown shape or dtype for the tensor's type
dtype = condition.dtype
if dtype != "bool":
raise ValueError(f"The condition must be a bool scalar, but given a {dtype} tensor")
shape = condition.shape
if len(shape) != 0:
raise ValueError(f"The condition must be a scalar, but it has a shape of {shape}")

val = condition.numpy()

else:
# should be guaranteed by the type system
raise ValueError(
f"The condition for relax assert must be a bool, int, or NDArray, "
f"but received a {type(condition)}."
)

if not val:
error_message = "Assertion Failed"
if format_args or format_str != "":
Expand All @@ -528,7 +535,7 @@ def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Ob


def assert_op(
condition: Expr,
condition: Union[Expr, PrimExpr],
format_args: Optional[Union[Expr, List[Expr]]] = None,
format: Union[str, Expr] = "",
) -> Expr:
Expand All @@ -538,7 +545,7 @@ def assert_op(

Parameters
----------
condition: Expr
condition: Union[Expr, PrimExpr]
The assertion condition.

format_args: Optional[Union[Expr, List[Expr]]]
Expand All @@ -552,12 +559,17 @@ def assert_op(
result : Expr
A Call to the Relax assert operation.
"""
if not isinstance(condition, Expr):
condition = tvm.relax.PrimValue(condition)

if format_args is None:
format_args = []
if isinstance(format_args, Expr): # type: ignore
elif isinstance(format_args, Expr):
format_args = [format_args]

if isinstance(format, str):
format = StringImm(format)

return _ffi_api.assert_op(condition, format_args, format) # type: ignore


Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
transform.LowerAllocTensor(),
transform.KillAfterLastUse(),
transform.VMBuiltinLower(),
transform.ComputePrimValue(),
transform.VMShapeLower(),
transform.AttachGlobalSymbol(),
],
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
CallTIRRewrite,
CanonicalizeBindings,
CombineParallelMatmul,
ComputePrimValue,
ConvertLayout,
ConvertToDataflow,
DataflowBlockPass,
Expand Down
19 changes: 19 additions & 0 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,25 @@ def KillAfterLastUse() -> tvm.ir.transform.Pass:
return _ffi_api.KillAfterLastUse() # type: ignore


def ComputePrimValue() -> tvm.ir.transform.Pass:
"""Compute all R.prim_value instances
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this description should be more precise. I assume it's supposed to come late in the phase ordering since it inserts direct calls to PrimFuncs? (And so should probably come after we end purity checking?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point on improving the docstring.

Regarding phase ordering, I don’t think we need to restrict its usage. The calls to PrimFunc instances are valid in user-provided Relax functions, so this could occur early in the phase ordering. The only limitation is that it must occur before VMShapeLower, as VMShapeLower expects all R.prim_value(arg) expressions to have int64 arguments.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but inserting the PrimFunc calls will likely change the purity of the functions where that happens. call_tir could be used to avoid that but then that will require using this before lowering call_tir.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. I think this is another argument in favor of allowing FuncStructInfo annotations for PrimFunc objects, as that would allow the generated PrimFunc instances to be marked as pure functions. I'll add a unit test to see how well that works for maintaining purity tracking when calling a pure PrimFunc from a pure Relax function.

Copy link
Contributor

@slyubomirsky slyubomirsky Feb 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are they truly pure? No modification of external values?

Edit: Yeah, they just use a return value. I imagine this means that we actually have to check the bodies of PrimFuncs to determine if they're pure and also give users the option to override the automatic judgment. The rules for that can be very simple: Consider it impure if there is any write to a tensor or external call.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking an even simpler heuristic: All PrimFuncs are impure, unless explicitly annotated otherwise. In this case, since the functions are being generated in a manner that requires purity, it could also provide the annotation.

For long-term, agreed, it would be good to have the TIR-level purity analysis. I think I'd weaken the condition you mentioned slightly: A function is impure if it writes to a buffer that it didn't itself allocate.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I'm fine with requiring an annotation since the great majority of PrimFuncs are going to be impure.


While high-level relax can include expressions in terms of its
symbolic variables, these expressions cannot natively be computed
within relax. In order to provide values for symbolic expressions
(e.g. `R.prim_value(N*N)`, where `N` is a symbolic variable), this
pass generates a PrimFunc in which the expression can be computed.
The relax graph is then updated to include a call to that
PrimFunc, in place of the original `R.prim_value(expr)`.

Returns
-------
ret : tvm.ir.transform.Pass

"""
return _ffi_api.ComputePrimValue() # type: ignore


def VMBuiltinLower() -> tvm.ir.transform.Pass:
"""Lowering generic intrinsic to VM intrinsics.

Expand Down
15 changes: 11 additions & 4 deletions python/tvm/script/ir_builder/relax/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,18 +511,25 @@ def SeqExpr() -> frame.SeqExprFrame: # pylint: disable=invalid-name
############################# If Then Else #############################


def If(condition: Expr) -> frame.IfFrame: # pylint: disable=invalid-name
def If(condition: Union[Expr, PrimExpr]) -> frame.IfFrame: # pylint: disable=invalid-name
"""Create an if frame.

Parameters
----------
condition : Expr
The condition of if statement, executes the true branch if the condition is true,
otherwise jump into the false branch.
condition : Union[Expr, PrimExpr]

The condition of if statement, executes the true branch if the
condition is true, otherwise jump into the false branch.

Returns
-------
res : frame.IfFrame
The result IfFrame.

"""
if not isinstance(condition, Expr):
condition = relax.PrimValue(condition)

return _ffi_api.If(condition) # type: ignore[attr-defined] # pylint: disable=no-member


Expand Down
33 changes: 26 additions & 7 deletions python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,12 +537,31 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar
The doc AST return node.
"""

ret_type = None
if node.returns is not None:
ret_type = self.eval_expr(node.returns)
if callable(ret_type):
ret_type = PrimType(ret_type().dtype)
supplied_annotation = self.function_annotations
func_annotation = supplied_annotation.get(node.name, {})

# Only ret_type is needed for func_signature.
func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type)
ret_type = None
with self.var_table.with_frame():
if node.returns is not None:
ret_type = self.eval_expr(node.returns)
if callable(ret_type):
ret_type = PrimType(ret_type().dtype)

arg_annotations = []
for arg in node.args.args:
if arg.annotation is None:
self.report_error(arg, "Type annotation required for function parameters.")
try:
ann = self.eval_expr(arg.annotation)
if callable(ann):
ann = ann()
except Exception: # pylint: disable=broad-except
ann = func_annotation.get(arg.arg, None)
if ann is None:
raise

IRBuilder.name(arg.arg, ann)
arg_annotations.append(ann)

func_signature = tvm.tir.PrimFunc(arg_annotations, None, ret_type=ret_type)
return I.decl_function(node.name, func_signature)
10 changes: 10 additions & 0 deletions python/tvm/tir/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,13 @@ def get_vtcm_compaction_passes() -> List[tvm.transform.Pass]:
returns list of passes
"""
return _ffi_api.get_vtcm_compaction_passes() # type: ignore # pylint: disable=no-member


def is_pure_function(func: PrimFunc) -> bool:
"""Checks if the function is a pure function"""
return _ffi_api.is_pure_function(func, False) # type: ignore # pylint: disable=no-member


def assert_pure_function(func: PrimFunc) -> bool:
"""Asserts that the function is a pure function"""
return _ffi_api.is_pure_function(func, True) # type: ignore # pylint: disable=no-member
6 changes: 4 additions & 2 deletions src/relax/analysis/struct_info_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -840,8 +840,10 @@ class CallRetStructInfoDeriver : public StructInfoBaseChecker {
auto params = finfo->params.value();
if (params.size() != call->args.size()) {
ctx->ReportFatal(Diagnostic::Error(call->span)
<< "number of arguments and parameters mismatch:"
<< " expected " << params.size() << ", given " << call->args.size());
<< "Number of arguments and parameters mismatch:"
<< " Function " << call->op << " has struct info " << finfo
<< " and accepts " << params.size() << " parameters, but was called with "
<< call->args.size() << " arguments (" << call->args << ")");
}
// Visit each param arg pair, check and populate the var map
for (size_t i = 0; i < params.size(); ++i) {
Expand Down
1 change: 1 addition & 0 deletions src/relax/backend/vm/vm_shape_lower.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class PrimExprSlotCollector : public ExprVisitor, public StructInfoVisitor {
collector.VisitExpr(param);
}
collector.VisitExpr(func->body);
collector.VisitStructInfo(func->ret_struct_info);
}

private:
Expand Down
4 changes: 2 additions & 2 deletions src/relax/op/tensor/inspect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ tir::PrimFunc GetDLTensorField(tir::builtin::TVMStructFieldKind field, DataType

FuncStructInfo sinfo({TensorStructInfo(DataType::Void(), kUnknownNDim)},
PrimStructInfo(field_dtype));
UpdateStructInfo(func, sinfo);
func->struct_info_ = sinfo;

return func;
}
Expand Down Expand Up @@ -338,7 +338,7 @@ Expr LegalizeTensorShape(const BlockBuilder& bb, const Call& call) {
FuncStructInfo sinfo(
{TensorStructInfo(DataType::Void(), kUnknownNDim), PrimStructInfo(axis->dtype)},
PrimStructInfo(field_dtype));
UpdateStructInfo(func, sinfo);
func->struct_info_ = sinfo;
return func;
}();

Expand Down
94 changes: 94 additions & 0 deletions src/relax/transform/compute_prim_value.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>

namespace tvm {
namespace relax {

namespace {

class PrimValueComputeInjector : public ExprMutator {
public:
IRModule Finalize() const { return builder_->Finalize(); }

using ExprMutator::VisitExpr_;

Expr VisitExpr_(const PrimValueNode* op) override {
auto node = Downcast<PrimValue>(ExprMutator::VisitExpr_(op));

if (node->value->IsInstance<tir::IntImmNode>() || node->value->IsInstance<tir::VarNode>()) {
return node;
}

auto ret_dtype = node->value->dtype;
auto param_vars = tir::UndefinedVars(node->value);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this call know which TIR vars are in scope per the Relax scoping rules?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This call isn’t aware of the Relax scoping rules, but I don’t think there’s a benefit of checking it at this point. Any well-formed input that only uses in-scope TIR variables would produce well-formed output. Any ill-formed input that uses out-of-scope TIR variables would produce ill-formed output that still uses the out-of-scope TIR variables.

Validating the relax scoping rules at this point would require additional tracking the in-scope variables, which would duplicate the functionality of the well-formed checker. Since this pass wouldn’t make any ill-formed usage worse (and therefore harder to debug), I don’t think it’s worth duplicating the in-scope tracking here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I think you're right that this would still work out just fine in that case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably overkill, but I think it ended up being simpler to just generate the appropriate FuncStructInfo for a PrimFunc on construction, and to populate the currently-empty struct_info_ field. This includes inspecting the body to see if the PrimFunc is pure.

tir::Stmt body = tir::Evaluate(tir::Call(ret_dtype, tir::builtin::ret(), {node->value}));

tir::PrimFunc func(param_vars, body, PrimType(ret_dtype));
func = tir::RenewDefs(func);

auto callee = builder_->AddFunction(func, "compute_symbolic_expr");

return relax::Call(callee, param_vars.Map([](const tir::Var& tir_var) -> relax::Expr {
return relax::PrimValue(tir_var);
}));
}
};

} // namespace

namespace transform {

Pass ComputePrimValue() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule mod, PassContext pc) -> IRModule {
PrimValueComputeInjector mutator;

IRModule updates;
for (const auto& [gvar, base_func] : mod->functions) {
if (auto func = base_func.as<Function>()) {
auto updated = Downcast<Function>(mutator(func.value()));
if (!updates.same_as(base_func)) {
updates->Add(gvar, updated);
}
}
}

if (updates->functions.size()) {
auto write_ptr = mod.CopyOnWrite();
write_ptr->Update(updates);
write_ptr->Update(mutator.Finalize());
}

return mod;
};
return CreateModulePass(pass_func, 0, "ComputePrimValue", {});
}

TVM_REGISTER_GLOBAL("relax.transform.ComputePrimValue").set_body_typed(ComputePrimValue);

} // namespace transform

} // namespace relax
} // namespace tvm
Loading