From 54629104909bc6cc7927efec9aa723c090611651 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 29 Feb 2024 03:43:01 +0000 Subject: [PATCH 1/4] [TIR][Analysis] Implemented tir.analysis.is_pure_function This commit introduces two related utilities, `tir.analysis.is_pure_function` and `tir.analysis.assert_pure_function`. In contrast to the existing `tvm::tir::SideEffect`, which checks for side effects on a for a `PrimExpr`, `is_pure_function` checks for side effects for the function as a whole. --- include/tvm/tir/analysis.h | 15 ++- python/tvm/error.py | 1 + python/tvm/tir/analysis/analysis.py | 10 ++ src/tir/analysis/is_pure_function.cc | 97 ++++++++++++++++ .../test_tir_analysis_is_pure_function.py | 104 ++++++++++++++++++ 5 files changed, 226 insertions(+), 1 deletion(-) create mode 100644 src/tir/analysis/is_pure_function.cc create mode 100644 tests/python/tir-analysis/test_tir_analysis_is_pure_function.py diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index c4ae5d573be9..96459f25ecc1 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -117,13 +117,26 @@ TVM_DLL Array UndefinedVars(const PrimExpr& expr); TVM_DLL Array UndefinedVars(const PrimExpr& expr, const Array& defs); /*! - * \brief Analyze the side effect + * \brief Analyze the side effect of an expression * \param expr The expression to be checked. * * \return CallEffectKind, can be kPure, kReadState or kUpdateState */ TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr); +/*! + * \brief Analyze the side effect of a function + * + * \param func The expression to be checked. + * + * \param assert_on_error If true, an error will be thrown for an + * impure function. If false (default), the purity of the PrimFunc + * will be returned. + * + * \return The purity of the function + */ +TVM_DLL bool IsPureFunction(const PrimFunc& func, bool assert_on_error = false); + /*! * \brief Whether the given Stmt uses any var in the given variable set. * \param stmt The Stmt to be checked. diff --git a/python/tvm/error.py b/python/tvm/error.py index cc0180593d5e..6bf9b1685085 100644 --- a/python/tvm/error.py +++ b/python/tvm/error.py @@ -54,6 +54,7 @@ def __init__(self, msg): register_error("AttributeError", AttributeError) register_error("KeyError", KeyError) register_error("IndexError", IndexError) +register_error("AssertionError", AssertionError) @register_error diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 8d7e81d7d0d8..67eb7471d22d 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -417,3 +417,13 @@ def get_vtcm_compaction_passes() -> List[tvm.transform.Pass]: returns list of passes """ return _ffi_api.get_vtcm_compaction_passes() # type: ignore # pylint: disable=no-member + + +def is_pure_function(func: PrimFunc) -> bool: + """Checks if the function is a pure function""" + return _ffi_api.is_pure_function(func, False) # type: ignore # pylint: disable=no-member + + +def assert_pure_function(func: PrimFunc) -> bool: + """Asserts that the function is a pure function""" + return _ffi_api.is_pure_function(func, True) # type: ignore # pylint: disable=no-member diff --git a/src/tir/analysis/is_pure_function.cc b/src/tir/analysis/is_pure_function.cc new file mode 100644 index 000000000000..c9934c4bcf6f --- /dev/null +++ b/src/tir/analysis/is_pure_function.cc @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file is_pure_function.cc + * \brief PrimFunc purity analysis + */ +#include +#include +#include + +#include "../ir/tir_visitor_with_path.h" + +namespace tvm { +namespace tir { + +namespace { +class PurityChecker : TIRVisitorWithPath { + public: + static bool Check(const PrimFunc& func, bool assert_on_error) { + PurityChecker visitor(assert_on_error); + visitor(func); + return visitor.is_pure_; + } + + private: + explicit PurityChecker(bool assert_on_error) : assert_on_error_(assert_on_error) {} + + void VisitStmt_(const AllocateNode* op, ObjectPath path) override { + internal_allocations_.insert(op->buffer_var); + TIRVisitorWithPath::VisitStmt_(op, path); + } + + void VisitStmt_(const BufferStoreNode* op, ObjectPath path) override { + TIRVisitorWithPath::VisitStmt_(op, path); + + if (!internal_allocations_.count(op->buffer->data)) { + is_pure_ = false; + LOG_IF(FATAL, assert_on_error_) << "AssertionError: " + << "Pure functions must not write to buffers, " + << ", but function contains store to " << op->buffer + << op->indices << " of value " << op->value; + } + } + + void VisitExpr_(const CallNode* call, ObjectPath path) override { + TIRVisitorWithPath::VisitExpr_(call, path); + + static auto op_call_effect = Op::GetAttrMap("TCallEffectKind"); + CallEffectKind effect = [&]() { + if (auto opt = call->op.as()) { + return static_cast(op_call_effect[opt.value()]->value); + } else { + return CallEffectKind::kOpaque; + } + }(); + + if (effect == CallEffectKind::kUpdateState || effect == CallEffectKind::kOpaque) { + is_pure_ = false; + LOG_IF(FATAL, assert_on_error_) + << "AssertionError: " + << "Pure functions must not contain calls to impure operators, " + << "but " << GetRef(call) << " calls operator " << call->op + << ", which has side effect " << effect; + } + } + + bool assert_on_error_{false}; + bool is_pure_{true}; + std::unordered_set internal_allocations_; +}; +} // namespace + +bool IsPureFunction(const PrimFunc& func, bool assert_on_error) { + return PurityChecker::Check(func, assert_on_error); +} + +TVM_REGISTER_GLOBAL("tir.analysis.is_pure_function").set_body_typed(IsPureFunction); + +} // namespace tir +} // namespace tvm diff --git a/tests/python/tir-analysis/test_tir_analysis_is_pure_function.py b/tests/python/tir-analysis/test_tir_analysis_is_pure_function.py new file mode 100644 index 000000000000..6555ae3f7757 --- /dev/null +++ b/tests/python/tir-analysis/test_tir_analysis_is_pure_function.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +import tvm.testing +from tvm.script import tir as T + +from tvm.tir.analysis import is_pure_function, assert_pure_function + + +class CheckPureFunction: + def test_check_purity(self): + assert is_pure_function(self.func) + + def test_assert_purity(self): + assert_pure_function(self.func) + + +class CheckImpureFunction: + def test_check_purity(self): + assert not is_pure_function(self.func) + + def test_assert_purity(self): + with pytest.raises(AssertionError): + assert_pure_function(self.func) + + +class TestNoOp(CheckPureFunction): + @T.prim_func + def func(): + pass + + +class TestReturnValue(CheckPureFunction): + @T.prim_func + def func() -> T.int32: + T.ret(42) + + +class TestComputeValueAndReturn(CheckPureFunction): + @T.prim_func + def func(N: T.int32, M: T.int32) -> T.int32: + T.ret(N * M) + + +class TestReadBufferArgument(CheckPureFunction): + @T.prim_func + def func(A: T.Buffer(16, "float32")) -> T.float32: + T.ret(A[0]) + + +class TestWriteToBufferArgument(CheckImpureFunction): + @T.prim_func + def func(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] + + +class TestWriteToInternalAllocation(CheckPureFunction): + @T.prim_func + def func(A: T.Buffer([16, 16], "float32")) -> T.float32: + Sum = T.decl_buffer([], "float32") + Sum[()] = 0.0 + for i, j in T.grid(16, 16): + Sum[()] = Sum[()] + A[i, j] + + T.ret(Sum[()]) + + +class TestCallPureBuiltin(CheckPureFunction): + @T.prim_func + def func(x: T.float32) -> T.float32: + T.ret(T.cos(x)) + + +class TestCallPureExtern(CheckPureFunction): + @T.prim_func + def func(): + T.call_pure_extern("some_pure_extern_func_name", dtype="void") + + +class TestCallImpureExtern(CheckImpureFunction): + @T.prim_func + def func(): + T.call_extern("some_impure_extern_func_name", dtype="void") + + +if __name__ == "__main__": + tvm.testing.main() From 2c64c81d45eaffda364c536aa0b3aadec7c367dc Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 23 Feb 2024 20:00:40 +0000 Subject: [PATCH 2/4] [Transform] Implement relax.transform.ComputePrimValue Prior to this commit, while expressions of type `DataType::Int(64)` could be computed in the `relax.transform.VMShapeLower`, expressions of any other type could not. This commit introduces `relax.transform.ComputePrimValue`, which produces `PrimFunc` subroutines to compute `PrimExpr` values of any dtype. This functionality will allow boolean values to be computed based on the symbolic values known at runtime. --- python/tvm/relax/transform/__init__.py | 1 + python/tvm/relax/transform/transform.py | 10 ++ src/relax/transform/compute_prim_value.cc | 94 +++++++++++++++++++ .../test_transform_compute_prim_value.py | 80 ++++++++++++++++ 4 files changed, 185 insertions(+) create mode 100644 src/relax/transform/compute_prim_value.cc create mode 100644 tests/python/relax/test_transform_compute_prim_value.py diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index 5f10c39d825b..11e301c26cca 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -28,6 +28,7 @@ CallTIRRewrite, CanonicalizeBindings, CombineParallelMatmul, + ComputePrimValue, ConvertLayout, ConvertToDataflow, DataflowBlockPass, diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index ef10f5791dbb..7daf7ef04716 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -486,6 +486,16 @@ def KillAfterLastUse() -> tvm.ir.transform.Pass: return _ffi_api.KillAfterLastUse() # type: ignore +def ComputePrimValue() -> tvm.ir.transform.Pass: + """Compute all R.prim_value instances + + Returns + ------- + ret : tvm.ir.transform.Pass + """ + return _ffi_api.ComputePrimValue() # type: ignore + + def VMBuiltinLower() -> tvm.ir.transform.Pass: """Lowering generic intrinsic to VM intrinsics. diff --git a/src/relax/transform/compute_prim_value.cc b/src/relax/transform/compute_prim_value.cc new file mode 100644 index 000000000000..9fe2a3a06fb7 --- /dev/null +++ b/src/relax/transform/compute_prim_value.cc @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relax { + +namespace { + +class PrimValueComputeInjector : public ExprMutator { + public: + IRModule Finalize() const { return builder_->Finalize(); } + + using ExprMutator::VisitExpr_; + + Expr VisitExpr_(const PrimValueNode* op) override { + auto node = Downcast(ExprMutator::VisitExpr_(op)); + + if (node->value->IsInstance() || node->value->IsInstance()) { + return node; + } + + auto ret_dtype = node->value->dtype; + auto param_vars = tir::UndefinedVars(node->value); + tir::Stmt body = tir::Evaluate(tir::Call(ret_dtype, tir::builtin::ret(), {node->value})); + + tir::PrimFunc func(param_vars, body, PrimType(ret_dtype)); + func = tir::RenewDefs(func); + + auto callee = builder_->AddFunction(func, "compute_symbolic_expr"); + + return relax::Call(callee, param_vars.Map([](const tir::Var& tir_var) -> relax::Expr { + return relax::PrimValue(tir_var); + })); + } +}; + +} // namespace + +namespace transform { + +Pass ComputePrimValue() { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, PassContext pc) -> IRModule { + PrimValueComputeInjector mutator; + + IRModule updates; + for (const auto& [gvar, base_func] : mod->functions) { + if (auto func = base_func.as()) { + auto updated = Downcast(mutator(func.value())); + if (!updates.same_as(base_func)) { + updates->Add(gvar, updated); + } + } + } + + if (updates->functions.size()) { + auto write_ptr = mod.CopyOnWrite(); + write_ptr->Update(updates); + write_ptr->Update(mutator.Finalize()); + } + + return mod; + }; + return CreateModulePass(pass_func, 0, "ComputePrimValue", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.ComputePrimValue").set_body_typed(ComputePrimValue); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_transform_compute_prim_value.py b/tests/python/relax/test_transform_compute_prim_value.py new file mode 100644 index 000000000000..d746272192d5 --- /dev/null +++ b/tests/python/relax/test_transform_compute_prim_value.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm.script import ir as I, relax as R, tir as T + + +class BaseCompare(tvm.testing.CompareBeforeAfter): + transform = tvm.relax.transform.ComputePrimValue() + + +class TestPrimValueInAssertCondition(BaseCompare): + @I.ir_module + class Before: + @R.function(pure=False) + def main(A: R.Tensor(["N"])): + N = T.int64() + _ = R.assert_op(N % 16 == 0) + return A + + @I.ir_module + class Expected: + @R.function(pure=False) + def main(A: R.Tensor(["N"])): + N = T.int64() + condition: R.Prim("bool") = Expected.compute_symbolic_expr(R.prim_value(N)) + _ = R.assert_op(condition) + return A + + @T.prim_func(private=True) + def compute_symbolic_expr(N: T.int64) -> T.bool: + T.ret(N % 16 == 0) + + +class TestPrimValueInBranchCondition(BaseCompare): + @I.ir_module + class Before: + @R.function(pure=False) + def main(A: R.Tensor(["N"])): + N = T.int64() + if R.prim_value(N % 16 == 0): + out = R.call_packed("fast_vectorized_impl", A, sinfo_args=[A.struct_info]) + else: + out = R.call_packed("slow_non_vectorized_impl", A, sinfo_args=[A.struct_info]) + return out + + @I.ir_module + class Expected: + @R.function(pure=False) + def main(A: R.Tensor(["N"])): + N = T.int64() + condition: R.Prim("bool") = Expected.compute_symbolic_expr(R.prim_value(N)) + if condition: + out = R.call_packed("fast_vectorized_impl", A, sinfo_args=[A.struct_info]) + else: + out = R.call_packed("slow_non_vectorized_impl", A, sinfo_args=[A.struct_info]) + return out + + @T.prim_func(private=True) + def compute_symbolic_expr(N: T.int64) -> T.bool: + T.ret(N % 16 == 0) + + +if __name__ == "__main__": + tvm.testing.main() From f5547383120988df029c5ebd3fbb28c929c9b16b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 23 Feb 2024 20:49:37 +0000 Subject: [PATCH 3/4] [Relax] Allow R.Prim('bool') in relax::If and assert_op Prior to this commit, the condition used for `relax::If` node and the `"relax.assert_op"` operator was required to be a scalar tensor. This made it difficult to alter behavior based on a runtime shape parameter. For example, delegating to a vectorized implementation based on a whether a tensor shape is divisible by the vector size. This commit adds support for expressions of type `R.Prim('bool')` as the conditional for `relax::If` and `"relax.assert_op"`, to allow these use cases. --- python/tvm/relax/op/base.py | 44 ++-- python/tvm/relax/pipeline.py | 1 + python/tvm/relax/transform/transform.py | 9 + python/tvm/script/ir_builder/relax/ir.py | 15 +- python/tvm/script/parser/tir/parser.py | 33 ++- src/relax/analysis/struct_info_analysis.cc | 6 +- src/relax/backend/vm/vm_shape_lower.cc | 1 + src/relax/op/tensor/inspect.cc | 4 +- src/relax/transform/dataflow_inplace.cc | 45 ++-- src/relax/utils.cc | 17 +- src/tir/ir/function.cc | 43 ++++ src/tir/ir/specialize.cc | 10 +- src/tir/transforms/renew_defs.cc | 6 +- .../python/relax/test_analysis_well_formed.py | 45 ++++ .../test_backend_transform_shape_lower.py | 84 ++++++++ tests/python/relax/test_relax_operators.py | 195 +++++++++++++----- tests/python/relax/test_transform.py | 12 +- .../test_transform_compute_prim_value.py | 24 +++ tests/python/relax/test_tvmscript_parser.py | 147 ++++++++++++- tests/python/relax/test_vm_codegen_tir.py | 2 +- tests/python/tir-base/test_tir_specialize.py | 27 +-- .../tvmscript/test_tvmscript_parser_tir.py | 109 ++++++++++ 22 files changed, 742 insertions(+), 137 deletions(-) diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 3effec242d64..756d250c1687 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -503,19 +503,26 @@ def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Ob f"The format string argument to assert must be a string, given {type(format_str)})" ) - # should be guaranteed by the type system - if not isinstance(condition, tvm.nd.NDArray): - raise ValueError(f"The condition must be an NDArray, but given a {type(condition)}.") - - # may happen if the original program had unknown shape or dtype for the tensor's type - dtype = condition.dtype - if dtype != "bool": - raise ValueError(f"The condition must be a bool scalar, but given a {dtype} tensor") - shape = condition.shape - if len(shape) != 0: - raise ValueError(f"The condition must be a scalar, but it has a shape of {shape}") - - val = condition.numpy() + if isinstance(condition, (bool, int)): + val = condition + elif isinstance(condition, tvm.nd.NDArray): + # may happen if the original program had unknown shape or dtype for the tensor's type + dtype = condition.dtype + if dtype != "bool": + raise ValueError(f"The condition must be a bool scalar, but given a {dtype} tensor") + shape = condition.shape + if len(shape) != 0: + raise ValueError(f"The condition must be a scalar, but it has a shape of {shape}") + + val = condition.numpy() + + else: + # should be guaranteed by the type system + raise ValueError( + f"The condition for relax assert must be a bool, int, or NDArray, " + f"but received a {type(condition)}." + ) + if not val: error_message = "Assertion Failed" if format_args or format_str != "": @@ -528,7 +535,7 @@ def relax_assert_op(condition: tvm.Object, format_str: str, *format_args: tvm.Ob def assert_op( - condition: Expr, + condition: Union[Expr, PrimExpr], format_args: Optional[Union[Expr, List[Expr]]] = None, format: Union[str, Expr] = "", ) -> Expr: @@ -538,7 +545,7 @@ def assert_op( Parameters ---------- - condition: Expr + condition: Union[Expr, PrimExpr] The assertion condition. format_args: Optional[Union[Expr, List[Expr]]] @@ -552,12 +559,17 @@ def assert_op( result : Expr A Call to the Relax assert operation. """ + if not isinstance(condition, Expr): + condition = tvm.relax.PrimValue(condition) + if format_args is None: format_args = [] - if isinstance(format_args, Expr): # type: ignore + elif isinstance(format_args, Expr): format_args = [format_args] + if isinstance(format, str): format = StringImm(format) + return _ffi_api.assert_op(condition, format_args, format) # type: ignore diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py index 474833bdfdcf..36ba46a1a5e3 100644 --- a/python/tvm/relax/pipeline.py +++ b/python/tvm/relax/pipeline.py @@ -92,6 +92,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I transform.LowerAllocTensor(), transform.KillAfterLastUse(), transform.VMBuiltinLower(), + transform.ComputePrimValue(), transform.VMShapeLower(), transform.AttachGlobalSymbol(), ], diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 7daf7ef04716..dbc35d48d303 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -489,9 +489,18 @@ def KillAfterLastUse() -> tvm.ir.transform.Pass: def ComputePrimValue() -> tvm.ir.transform.Pass: """Compute all R.prim_value instances + While high-level relax can include expressions in terms of its + symbolic variables, these expressions cannot natively be computed + within relax. In order to provide values for symbolic expressions + (e.g. `R.prim_value(N*N)`, where `N` is a symbolic variable), this + pass generates a PrimFunc in which the expression can be computed. + The relax graph is then updated to include a call to that + PrimFunc, in place of the original `R.prim_value(expr)`. + Returns ------- ret : tvm.ir.transform.Pass + """ return _ffi_api.ComputePrimValue() # type: ignore diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 3e1927290dcc..6dbf5c5dfdb4 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -511,18 +511,25 @@ def SeqExpr() -> frame.SeqExprFrame: # pylint: disable=invalid-name ############################# If Then Else ############################# -def If(condition: Expr) -> frame.IfFrame: # pylint: disable=invalid-name +def If(condition: Union[Expr, PrimExpr]) -> frame.IfFrame: # pylint: disable=invalid-name """Create an if frame. + Parameters ---------- - condition : Expr - The condition of if statement, executes the true branch if the condition is true, - otherwise jump into the false branch. + condition : Union[Expr, PrimExpr] + + The condition of if statement, executes the true branch if the + condition is true, otherwise jump into the false branch. + Returns ------- res : frame.IfFrame The result IfFrame. + """ + if not isinstance(condition, Expr): + condition = relax.PrimValue(condition) + return _ffi_api.If(condition) # type: ignore[attr-defined] # pylint: disable=no-member diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 0f3f3de60fe3..679ae4e8adc0 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -537,12 +537,31 @@ def visit_tvm_declare_function(self: Parser, node: doc.FunctionDef) -> GlobalVar The doc AST return node. """ - ret_type = None - if node.returns is not None: - ret_type = self.eval_expr(node.returns) - if callable(ret_type): - ret_type = PrimType(ret_type().dtype) + supplied_annotation = self.function_annotations + func_annotation = supplied_annotation.get(node.name, {}) - # Only ret_type is needed for func_signature. - func_signature = tvm.tir.PrimFunc([], None, ret_type=ret_type) + ret_type = None + with self.var_table.with_frame(): + if node.returns is not None: + ret_type = self.eval_expr(node.returns) + if callable(ret_type): + ret_type = PrimType(ret_type().dtype) + + arg_annotations = [] + for arg in node.args.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation required for function parameters.") + try: + ann = self.eval_expr(arg.annotation) + if callable(ann): + ann = ann() + except Exception: # pylint: disable=broad-except + ann = func_annotation.get(arg.arg, None) + if ann is None: + raise + + IRBuilder.name(arg.arg, ann) + arg_annotations.append(ann) + + func_signature = tvm.tir.PrimFunc(arg_annotations, None, ret_type=ret_type) return I.decl_function(node.name, func_signature) diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index b1932f9b5d67..08e2acfbd069 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -840,8 +840,10 @@ class CallRetStructInfoDeriver : public StructInfoBaseChecker { auto params = finfo->params.value(); if (params.size() != call->args.size()) { ctx->ReportFatal(Diagnostic::Error(call->span) - << "number of arguments and parameters mismatch:" - << " expected " << params.size() << ", given " << call->args.size()); + << "Number of arguments and parameters mismatch:" + << " Function " << call->op << " has struct info " << finfo + << " and accepts " << params.size() << " parameters, but was called with " + << call->args.size() << " arguments (" << call->args << ")"); } // Visit each param arg pair, check and populate the var map for (size_t i = 0; i < params.size(); ++i) { diff --git a/src/relax/backend/vm/vm_shape_lower.cc b/src/relax/backend/vm/vm_shape_lower.cc index 5875ad55628c..84d9f4cf2366 100644 --- a/src/relax/backend/vm/vm_shape_lower.cc +++ b/src/relax/backend/vm/vm_shape_lower.cc @@ -85,6 +85,7 @@ class PrimExprSlotCollector : public ExprVisitor, public StructInfoVisitor { collector.VisitExpr(param); } collector.VisitExpr(func->body); + collector.VisitStructInfo(func->ret_struct_info); } private: diff --git a/src/relax/op/tensor/inspect.cc b/src/relax/op/tensor/inspect.cc index 186fc9fa8690..3772e530edf7 100644 --- a/src/relax/op/tensor/inspect.cc +++ b/src/relax/op/tensor/inspect.cc @@ -107,7 +107,7 @@ tir::PrimFunc GetDLTensorField(tir::builtin::TVMStructFieldKind field, DataType FuncStructInfo sinfo({TensorStructInfo(DataType::Void(), kUnknownNDim)}, PrimStructInfo(field_dtype)); - UpdateStructInfo(func, sinfo); + func->struct_info_ = sinfo; return func; } @@ -338,7 +338,7 @@ Expr LegalizeTensorShape(const BlockBuilder& bb, const Call& call) { FuncStructInfo sinfo( {TensorStructInfo(DataType::Void(), kUnknownNDim), PrimStructInfo(axis->dtype)}, PrimStructInfo(field_dtype)); - UpdateStructInfo(func, sinfo); + func->struct_info_ = sinfo; return func; }(); diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index 755c5dbab433..091298177595 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -877,10 +877,12 @@ class ModuleInplaceTransformer : public ExprMutator { auto inline_legal_op_name = legal_op->name_hint + "_inplace"; auto mod = builder_->GetContextIRModule(); - auto legal_primfunc = Downcast(mod->Lookup(legal_op)); - auto* legal_primfunc_cow = legal_primfunc.CopyOnWrite(); + auto old_primfunc = Downcast(mod->Lookup(legal_op)); + + tir::Stmt new_body = old_primfunc->body; + size_t num_outs = inplace_indices.size(); - size_t num_params = legal_primfunc->params.size(); + size_t num_params = old_primfunc->params.size(); // the replacement we must make: // 1. For each output var, replace its corresponding buffers with the corresponding inplace @@ -893,42 +895,43 @@ class ModuleInplaceTransformer : public ExprMutator { Map var_subst_map; for (size_t i = 0; i < num_outs; i++) { // we will substitute output i with the corresponding param indicated by inplace indices - auto output_var = legal_primfunc->params[num_params - num_outs + i]; - auto inplace_var = legal_primfunc->params[inplace_indices[i].IntValue()]; + auto output_var = old_primfunc->params[num_params - num_outs + i]; + auto inplace_var = old_primfunc->params[inplace_indices[i].IntValue()]; var_subst_map.Set(output_var, inplace_var); // also do the same with the buffer vars - auto output_buffer = legal_primfunc->buffer_map.at(output_var); - auto inplace_buffer = legal_primfunc->buffer_map.at(inplace_var); + auto output_buffer = old_primfunc->buffer_map.at(output_var); + auto inplace_buffer = old_primfunc->buffer_map.at(inplace_var); var_subst_map.Set(output_buffer->data, inplace_buffer->data); buffer_subst_map.Set(output_buffer, inplace_buffer); } // apply substitutions - legal_primfunc_cow->body = RemapBuffers(legal_primfunc->body, buffer_subst_map); - legal_primfunc_cow->body = tir::Substitute( - legal_primfunc->body, [&var_subst_map](const tir::Var& v) -> Optional { - if (var_subst_map.count(v)) { - return var_subst_map.at(v); - } - return Optional(); - }); + new_body = RemapBuffers(new_body, buffer_subst_map); + new_body = tir::Substitute(new_body, [&var_subst_map](const tir::Var& v) -> Optional { + if (var_subst_map.count(v)) { + return var_subst_map.at(v); + } + return Optional(); + }); // remove the now-unused outputs from the buffer map - auto buffer_map = legal_primfunc->buffer_map; + auto new_buffer_map = old_primfunc->buffer_map; for (size_t i = 0; i < num_outs; i++) { - buffer_map.erase(legal_primfunc->params[num_params - num_outs + i]); + new_buffer_map.erase(old_primfunc->params[num_params - num_outs + i]); } - legal_primfunc_cow->buffer_map = buffer_map; // now get rid of the last num_outputs arguments // (couldn't do earlier or else it would have thrown off the indexing) - legal_primfunc_cow->params = Array( - legal_primfunc->params.begin(), legal_primfunc->params.begin() + (num_params - num_outs)); + Array new_params(old_primfunc->params.begin(), + old_primfunc->params.begin() + (num_params - num_outs)); + + tir::PrimFunc new_primfunc(new_params, new_body, old_primfunc->ret_type, new_buffer_map, + old_primfunc->attrs, old_primfunc->span); // note: this might be a good time to get rid of the old legalized function, but we don't do it // now because later ops might need the same one. Instead, we will clean up at the end - auto new_gv = builder_->AddFunction(legal_primfunc, inline_legal_op_name); + auto new_gv = builder_->AddFunction(new_primfunc, inline_legal_op_name); // update the call (change the op, update the argument, change the attrs) legalized_call_cow->op = call_tir_inplace_op; diff --git a/src/relax/utils.cc b/src/relax/utils.cc index efb2d0220481..a15ee79facbf 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -220,12 +220,21 @@ tvm::Map InferSymbolicVarMap( bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank, bool permit_unknown_dtype) { - const TensorStructInfoNode* tt = sinfo.as(); - if (!tt) { + DataType dtype; + int ndim; + + if (const auto* tensor = sinfo.as()) { + dtype = tensor->dtype; + ndim = tensor->ndim; + } else if (const auto* prim = sinfo.as()) { + dtype = prim->dtype; + ndim = 0; + } else { return false; } - bool correct_dtype = tt->dtype.is_bool() || (permit_unknown_dtype && tt->dtype.is_void()); - bool correct_rank = tt->ndim == 0 || (permit_unknown_rank && tt->ndim == -1); + + bool correct_dtype = dtype.is_bool() || (permit_unknown_dtype && dtype.is_void()); + bool correct_rank = ndim == 0 || (permit_unknown_rank && ndim == -1); return correct_dtype && correct_rank; } diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 5067d9083863..8a3d2d69474f 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -21,12 +21,52 @@ * \file src/tir/ir/function.cc * \brief The function data structure. */ +#include #include +#include #include #include namespace tvm { namespace tir { +namespace { +relax::StructInfo InferStructInfo(const PrimFunc& prim_func) { + Array params; + for (const auto& param : prim_func->params) { + relax::StructInfo param_sinfo = [&]() -> relax::StructInfo { + if (auto opt_buf = prim_func->buffer_map.Get(param)) { + auto buf = opt_buf.value(); + relax::ShapeExpr shape( + buf->shape.Map([](PrimExpr dim) { return cast(DataType::Int(64), dim); })); + return relax::TensorStructInfo(shape, buf->dtype); + } + + if (auto prim_type = param->type_annotation.as(); + prim_type && prim_type->dtype.is_handle()) { + return relax::ObjectStructInfo(); + } + + return relax::PrimStructInfo(param->dtype); + }(); + params.push_back(param_sinfo); + } + + relax::StructInfo ret = [&]() -> relax::StructInfo { + if (const auto* prim = prim_func->ret_type.as()) { + return relax::PrimStructInfo(prim->dtype); + } else if (IsVoidType(prim_func->ret_type)) { + return relax::TupleStructInfo(Array{}); + } else { + return relax::ObjectStructInfo(); + } + }(); + + bool purity = prim_func->body.defined() ? IsPureFunction(prim_func) : false; + + return relax::FuncStructInfo(params, ret, purity); +} +} // namespace + // Get the function type of a PrimFunc PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, Map buffer_map, DictAttrs attrs, Span span) { @@ -42,8 +82,11 @@ PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, n->buffer_map = std::move(buffer_map); n->attrs = std::move(attrs); n->checked_type_ = n->func_type_annotation(); + n->struct_info_ = relax::FuncStructInfo::OpaqueFunc(); n->span = std::move(span); data_ = std::move(n); + + (*this)->struct_info_ = InferStructInfo(*this); } FuncType PrimFuncNode::func_type_annotation() const { diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index 8095b3141fbf..924ef9a0cdde 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -105,14 +105,10 @@ class PrimFuncSpecializer : public StmtExprMutator { Stmt body = specializer(f->body); if (param_updated || buffer_map_updated || !f->body.same_as(body)) { - PrimFuncNode* f_ptr = f.CopyOnWrite(); - f_ptr->params = std::move(params); - f_ptr->buffer_map = std::move(buffer_map); - f_ptr->body = std::move(body); - f_ptr->struct_info_ = NullOpt; - f_ptr->checked_type_ = Type(nullptr); + return PrimFunc(params, body, f->ret_type, buffer_map, f->attrs, f->span); + } else { + return f; } - return f; } private: diff --git a/src/tir/transforms/renew_defs.cc b/src/tir/transforms/renew_defs.cc index 8a122f892204..28d1100f6b53 100644 --- a/src/tir/transforms/renew_defs.cc +++ b/src/tir/transforms/renew_defs.cc @@ -76,11 +76,7 @@ class RenewDefMutator : public StmtExprMutator { // Visit body Stmt body = generator(func->body); // Recreate function - auto n = make_object(*func.get()); - n->params = std::move(params); - n->buffer_map = std::move(buffer_map); - n->body = std::move(body); - return PrimFunc(n); + return PrimFunc(params, body, func->ret_type, buffer_map, func->attrs, func->span); } private: diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index b76b95646a72..fe171d9606c9 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -22,6 +22,7 @@ from tvm.script import relax as R from tvm.script import ir as I from tvm.script import tir as T +from tvm.script import ir as I m = tir.Var("m", "int64") n = tir.Var("n", "int64") @@ -655,6 +656,50 @@ def subroutine(A: R.Tensor([16, 32], "float32"), B: R.Tensor([32, 64], "float32" assert rx.analysis.well_formed(Module["main"]) assert rx.analysis.well_formed(Module["subroutine"]) +def test_pass_dltensor_arg_to_tir(): + """Relax may pass R.Tensor as DLTensor + + In TIR, a `DLTensor*` argument with unknown shape and dtype is + represented as a `tir.Var` with + `tvm::PrimType(DataType::Handle())`, and with no entry in the + `PrimFuncNode::buffer_map`. In Relax, this is represented as + `R.Tensor`. Calls from Relax to TIR that pass a tensor of unknown + rank/shape are well-formed. + + In the test case below, a TIR function accepts an arbitrary + `R.Tensor`, and returns a boolean value based on inspection of the + runtime datatype. + """ + + @I.ir_module + class Module: + @R.function + def main(A: R.Tensor) -> R.Prim("bool"): + return Module.is_bfloat16_dtype(A) + + @T.prim_func(private=True) + def is_bfloat16_dtype(tensor: T.handle) -> T.bool: + T.func_attr({"tir.is_scheduled": True, "tir.is_host_func": True}) + + # From #include + kArrTypeCode = T.meta_var(5) + kArrTypeBits = T.meta_var(6) + kArrTypeLanes = T.meta_var(7) + + # From #include + kDLBfloat = T.meta_var(4) + + type_code = T.tvm_struct_get(tensor, 0, kArrTypeCode, dtype="uint8") + type_bits = T.tvm_struct_get(tensor, 0, kArrTypeBits, dtype="uint8") + type_lanes = T.tvm_struct_get(tensor, 0, kArrTypeLanes, dtype="uint16") + + is_bfloat16: T.bool = ( + (type_code == kDLBfloat) and (type_bits == 16) and (type_lanes == 1) + ) + return is_bfloat16 + + assert rx.analysis.well_formed(Module) + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py index 31eb4b26bee0..fccf3a5f8a1e 100644 --- a/tests/python/relax/test_backend_transform_shape_lower.py +++ b/tests/python/relax/test_backend_transform_shape_lower.py @@ -452,6 +452,90 @@ def main( assert_structural_equal(after, expected) +def test_return_match_check_with_new_expr(): + """Like test_return_match_check, but requires a computation + + When return body is not same as ret_struct_info, a runtime match + check is required. This match check may require a symbolic + expression to be computed. + """ + MS = MatchShapeCode + + @tvm.script.ir_module + class Before: + @R.function + def main(x: R.Tensor(["n", "n"], "float32")) -> R.Tensor(["n * n"], "float32"): + R.func_attr({"relax.force_pure": True}) + out = R.call_packed("flatten_matrix", x, sinfo_args=R.Object) + return out + + # slot assignment: + sindex = { + "n": 0, + "n * n": 1, + } + + @tvm.script.ir_module + class Expected: + @R.function + def main(x: R.Tensor(["n", "n"], "float32")) -> R.Tensor(["n * n"], "float32"): + R.func_attr({"relax.force_pure": True}) + shape_heap = R.call_builtin_with_ctx( + "vm.builtin.alloc_shape_heap", + [R.prim_value(2)], + sinfo_args=[R.Tensor(ndim=1, dtype="int64")], + ) + _ = R.call_packed( + "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "", sinfo_args=[R.Tuple()] + ) + _ = R.call_packed( + "vm.builtin.match_shape", + x, + shape_heap, + 2, + MS.STORE_TO_HEAP, + sindex["n"], + MS.ASSERT_EQUAL_TO_LOAD, + sindex["n"], + "", + sinfo_args=[R.Tuple()], + ) + + _ = Expected.shape_func(shape_heap) + + out = R.call_packed("flatten_matrix", x, sinfo_args=R.Object) + _ = R.call_packed( + "vm.builtin.check_tensor_info", + out, + 1, + R.dtype("float32"), + "", + sinfo_args=[R.Tuple()], + ) + _ = R.call_packed( + "vm.builtin.match_shape", + out, + shape_heap, + 1, + MS.ASSERT_EQUAL_TO_LOAD, + sindex["n * n"], + "", + sinfo_args=[R.Tuple()], + ) + return out + + @T.prim_func(private=True) + def shape_func(H: T.Buffer(T.int64(2), "int64")): + # generated compute function + T.func_attr({"tir.is_host_func": 1}) + H[T.int64(sindex["n * n"])] = H[T.int64(sindex["n"])] * H[T.int64(sindex["n"])] + + before = Before + expected = Expected + after = relax.transform.VMShapeLower(emit_err_ctx=False)(before) + assert_structural_equal(after, expected) + + def test_symbolic_shape_multiple_function(): MS = MatchShapeCode MK = MakeShapeCode diff --git a/tests/python/relax/test_relax_operators.py b/tests/python/relax/test_relax_operators.py index a278b0916772..41618a32cb55 100644 --- a/tests/python/relax/test_relax_operators.py +++ b/tests/python/relax/test_relax_operators.py @@ -19,6 +19,8 @@ import tempfile import numpy as np +import pytest + import tvm import tvm.testing from tvm import relax @@ -35,13 +37,18 @@ def foo(x: R.Tensor(("m", "n"), "int64")): return y, y_sorted -def run_cpu(mod, func_name, *input): +def run_cpu(mod, func_name, *args): + if isinstance(mod, relax.Function): + func = mod + args = [func_name, *args] + func_name = func.attrs["global_symbol"] + mod = tvm.IRModule.from_expr(func) + target = tvm.target.Target("llvm") ex = relax.build(mod, target) vm = relax.VirtualMachine(ex, tvm.cpu()) - vm.set_input(func_name, *input) - vm.invoke_stateful(func_name) - return vm.get_outputs(func_name) + + return vm[func_name](*args) def test_unique(): @@ -88,67 +95,108 @@ def test_print(): sys.stdout = stdout -@tvm.script.ir_module -class AssertOpTest: +def test_assert_passes(): @R.function(pure=False) - def passes(x: R.Tensor((), "int32")): - p1 = R.assert_op(relax.const(True)) + def func(x: R.Tensor((), "int32")): + _ = R.assert_op(relax.const(True)) return x + run_cpu(func, tvm.nd.array(np.array(1).astype("int32"))) + + +def test_assert_passes_with_format_args(): @R.function(pure=False) - def pass_with_args(x: R.Tensor((), "int32")): - p1 = R.assert_op(relax.const(True), x, format="You won't see me") + def func(x: R.Tensor((), "int32")): + _ = R.assert_op(relax.const(True), x, format="You won't see me") return x + run_cpu(func, tvm.nd.array(np.array(1).astype("int32"))) + + +def test_assert_fails(): + @R.function(pure=False) + def func(x: R.Tensor((), "int32")): + _ = R.assert_op(relax.const(False)) + return x + + with pytest.raises(AssertionError, match="Assertion Failed"): + run_cpu(func, tvm.nd.array(np.array(1).astype("int32"))) + + +def test_assert_fails_with_message(): @R.function(pure=False) - def simple_fail(x: R.Tensor((), "int32")): - p1 = R.assert_op(relax.const(False)) + def func(x: R.Tensor((), "int32")): + _ = R.assert_op(relax.const(False), format="I failed...") return x + with pytest.raises(AssertionError, match="I failed..."): + run_cpu(func, tvm.nd.array(np.array(1).astype("int32"))) + + +def test_assert_fails_with_args(): @R.function(pure=False) - def fail_with_message(x: R.Tensor((), "int32")): - p1 = R.assert_op(relax.const(False), format="I failed...") + def func(x: R.Tensor((), "int32")): + _ = R.assert_op(relax.const(False), [x, x]) return x + with pytest.raises(AssertionError, match="5, 5"): + run_cpu(func, tvm.nd.array(np.array(5).astype("int32"))) + + +def test_assert_fails_with_formatted_args(): @R.function(pure=False) - def fail_with_args(x: R.Tensor((), "int32")): - # no format - p1 = R.assert_op(relax.const(False), [x, x]) + def func(x: R.Tensor((), "int32")): + _ = R.assert_op(relax.const(False), x, format="Number: {}") return x + with pytest.raises(AssertionError, match="Number: 6"): + run_cpu(func, tvm.nd.array(np.array(6).astype("int32"))) + + +def test_assert_on_argument_passes(): @R.function(pure=False) - def fail_with_formatted_message(x: R.Tensor((), "int32")): - p1 = R.assert_op(relax.const(False), x, format="Number: {}") + def func(condition: R.Tensor((), "bool"), x: R.Tensor((), "int32")): + _ = R.assert_op(condition) return x + condition = tvm.nd.array(np.array(True)) + x = tvm.nd.array(np.array(5).astype("int32")) + run_cpu(func, condition, x) -def test_assert_op(): - def check_assertion_error(func_name, func_arg, expected_message): - passed = False - try: - run_cpu(AssertOpTest, func_name, func_arg) - passed = True - except TVMError as e: - # TVM will print out a TVMError that will contain the - # generated error at the bottom of a stack trace - assert "AssertionError" in e.args[0] - assert expected_message in e.args[0] - except AssertionError: - return - assert not passed - - run_cpu(AssertOpTest, "passes", tvm.nd.array(np.array(1).astype("int32"))) - run_cpu(AssertOpTest, "pass_with_args", tvm.nd.array(np.array(2).astype("int32"))) - check_assertion_error( - "simple_fail", tvm.nd.array(np.array(3).astype("int32")), "Assertion Failed" - ) - check_assertion_error( - "fail_with_message", tvm.nd.array(np.array(4).astype("int32")), "I failed..." - ) - check_assertion_error("fail_with_args", tvm.nd.array(np.array(5).astype("int32")), "5, 5") - check_assertion_error( - "fail_with_formatted_message", tvm.nd.array(np.array(6).astype("int32")), "Number: 6" - ) + +def test_assert_on_argument_fails(): + @R.function(pure=False) + def func(condition: R.Tensor((), "bool"), x: R.Tensor((), "int32")): + _ = R.assert_op(condition) + return x + + condition = tvm.nd.array(np.array(False)) + x = tvm.nd.array(np.array(5).astype("int32")) + with pytest.raises(AssertionError): + run_cpu(func, condition, x) + + +def test_assert_on_symbolic_var_passes(): + @R.function(pure=False) + def func(x: R.Tensor(["N"], "int32")): + N = T.int64() + _ = R.assert_op(R.prim_value(N % 8 == 0)) + return x + + x = tvm.nd.array(np.arange(8, dtype="int32")) + run_cpu(func, x) + + +def test_assert_on_symbolic_var_fails(): + @R.function(pure=False) + def func(x: R.Tensor(["N"], "int32")): + N = T.int64() + _ = R.assert_op(R.prim_value(N % 8 == 0)) + return x + + x = tvm.nd.array(np.arange(10, dtype="int32")) + with pytest.raises(AssertionError): + run_cpu(func, x) @tvm.script.ir_module @@ -370,5 +418,60 @@ def to_vdev(x: R.Tensor((3, 4), "float32")): assert (copy_found.numpy() == arr).all() +def test_scalar_tensor_as_branch_condition(): + """The condition of a branch may be a scalar tensor""" + + @R.function + def func(condition: R.Tensor((), "bool")): + if condition: + out = R.prim_value(5) + else: + out = R.prim_value(10) + return out + + res = run_cpu(func, tvm.nd.array(np.array(True))) + assert res == 5 + + res = run_cpu(func, tvm.nd.array(np.array(False))) + assert res == 10 + + +def test_prim_value_as_branch_condition(): + """The condition may be a PrimValue""" + + @R.function + def func(condition: R.Prim("bool")): + if condition: + out = R.prim_value(5) + else: + out = R.prim_value(10) + return out + + res = run_cpu(func, True) + assert res == 5 + + res = run_cpu(func, False) + assert res == 10 + + +def test_computed_prim_value_as_branch_condition(): + """The R.Prim condition may be computed within the function""" + + @R.function + def func(x: R.Tensor(["N"], "int64")): + N = T.int64() + if R.prim_value(N % 16 == 0): + out = R.prim_value(5) + else: + out = R.prim_value(10) + return out + + res = run_cpu(func, tvm.nd.array(np.arange(16))) + assert res == 5 + + res = run_cpu(func, tvm.nd.array(np.arange(20))) + assert res == 10 + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index 9ab2ffc60536..7fbf9a2da141 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -343,14 +343,18 @@ def foo( @tvm.script.ir_module class Expected: @T.prim_func - def copy(A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32")): + def copy( + A: T.Buffer((2, 3), "int32"), B: T.Buffer((2, 3), "int32"), C: T.Buffer((2, 3), "int32") + ): + # copies the contents of C into A and B T.func_attr({"tir.noalias": True}) for i0, i1 in T.grid(T.int64(2), T.int64(3)): with T.block("T_zeros"): ax0, ax1 = T.axis.remap("SS", [i0, i1]) - T.reads(B[ax0, ax1]) - T.writes(A[ax0, ax1]) - A[ax0, ax1] = B[ax0, ax1] + T.reads(C[ax0, ax1]) + T.writes(A[ax0, ax1], B[ax0, ax1]) + A[ax0, ax1] = C[ax0, ax1] + B[ax0, ax1] = C[ax0, ax1] @R.function def foo( diff --git a/tests/python/relax/test_transform_compute_prim_value.py b/tests/python/relax/test_transform_compute_prim_value.py index d746272192d5..9fee35414d0d 100644 --- a/tests/python/relax/test_transform_compute_prim_value.py +++ b/tests/python/relax/test_transform_compute_prim_value.py @@ -76,5 +76,29 @@ def compute_symbolic_expr(N: T.int64) -> T.bool: T.ret(N % 16 == 0) +class TestPrimValueInPureFunction(BaseCompare): + @I.ir_module + class Before: + @R.function + def main(_N: R.Prim(value="N"), _M: R.Prim(value="M")) -> R.Prim(value="N*M"): + N = T.int64() + M = T.int64() + out = R.prim_value(N * M) + return out + + @I.ir_module + class Expected: + @R.function + def main(_N: R.Prim(value="N"), _M: R.Prim(value="M")) -> R.Prim(value="N*M"): + N = T.int64() + M = T.int64() + out = Expected.compute_symbolic_expr(R.prim_value(N), R.prim_value(M)) + return out + + @T.prim_func(private=True) + def compute_symbolic_expr(N: T.int64, M: T.int64) -> T.int64: + T.ret(N * M) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 2221cb89eb20..c8db26c81bac 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -1261,6 +1261,149 @@ def foo(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")): return w +def test_scalar_tensor_as_branch_condition(): + """Branch condition can be 0-d tensor""" + + @R.function + def func(cond: R.Tensor([], "bool"), x: R.Tensor((1,), "float32")): + if cond: + out = R.add(x, x) + else: + out = R.multiply(x, x) + return out + + if_else = func.body.blocks[0].bindings[0].value + assert isinstance(if_else.cond, relax.Var) + tvm.ir.assert_structural_equal(if_else.cond.struct_info, R.Tensor([], "bool")) + + +def test_prim_value_as_branch_condition(): + """In addition to scalar tensor, can use R.Prim condition""" + + @R.function + def func(cond: R.Prim("bool"), x: R.Tensor((1,), "float32")): + if cond: + out = R.add(x, x) + else: + out = R.multiply(x, x) + return out + + if_else = func.body.blocks[0].bindings[0].value + assert isinstance(if_else.cond, relax.Var) + tvm.ir.assert_structural_equal(if_else.cond.struct_info, R.Prim("bool")) + + +def test_computed_prim_value_as_branch_condition(): + """The R.Prim condition may be computed within the function""" + + @R.function + def func(x: R.Tensor(["N"], "float32")): + N = T.int64() + if R.prim_value(N % 16 == 0): + out = R.call_pure_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info]) + else: + out = R.call_pure_packed("slow_non_vectorized_impl", x, sinfo_args=[x.struct_info]) + return out + + N = func.params[0].struct_info.shape[0] + if_else = func.body.blocks[0].bindings[0].value + assert isinstance(if_else.cond, relax.PrimValue) + tvm.ir.assert_structural_equal(N % 16 == 0, if_else.cond.value) + tvm.ir.assert_structural_equal(if_else.cond.struct_info, R.Prim(value=N % 16 == 0)) + + +def test_tir_expr_as_branch_condition(): + """Syntactic sugar, wrap PrimExpr as PrimValue""" + + @R.function(private=True) + def sugared(x: R.Tensor(["N"], "float32")): + N = T.int64() + if N % 16 == 0: + out = R.call_pure_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info]) + else: + out = R.call_pure_packed("slow_non_vectorized_impl", x, sinfo_args=[x.struct_info]) + return out + + @R.function(private=True) + def unsugared(x: R.Tensor(["N"], "float32")): + N = T.int64() + if R.prim_value(N % 16 == 0): + out = R.call_pure_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info]) + else: + out = R.call_pure_packed("slow_non_vectorized_impl", x, sinfo_args=[x.struct_info]) + return out + + tvm.ir.assert_structural_equal(unsugared, sugared) + + +def test_scalar_tensor_as_assert_condition(): + """Branch condition can be 0-d tensor""" + + @R.function(pure=False) + def func(cond: R.Tensor([], "bool"), x: R.Tensor((1,), "float32")): + _ = R.assert_op(cond) + out = R.add(x, x) + return out + + assert_op = func.body.blocks[0].bindings[0].value + condition = assert_op.args[0] + assert isinstance(condition, relax.Var) + tvm.ir.assert_structural_equal(condition.struct_info, R.Tensor([], "bool")) + + +def test_prim_value_as_assert_condition(): + """In addition to scalar tensor, can use R.Prim condition""" + + @R.function(pure=False) + def func(cond: R.Prim("bool"), x: R.Tensor((1,), "float32")): + _ = R.assert_op(cond) + out = R.add(x, x) + return out + + assert_op = func.body.blocks[0].bindings[0].value + condition = assert_op.args[0] + assert isinstance(condition, relax.Var) + tvm.ir.assert_structural_equal(condition.struct_info, R.Prim("bool")) + + +def test_computed_prim_value_as_assert_condition(): + """The R.Prim condition may be computed within the function""" + + @R.function(pure=False) + def func(x: R.Tensor(["N"], "float32")): + N = T.int64() + _ = R.assert_op(R.prim_value(N % 16 == 0)) + out = R.call_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info]) + return out + + N = func.params[0].struct_info.shape[0] + assert_op = func.body.blocks[0].bindings[0].value + condition = assert_op.args[0] + assert isinstance(condition, relax.PrimValue) + tvm.ir.assert_structural_equal(N % 16 == 0, condition.value) + tvm.ir.assert_structural_equal(condition.struct_info, R.Prim(value=N % 16 == 0)) + + +def test_tir_expr_as_assert_condition(): + """Syntactic sugar, wrap PrimExpr as PrimValue""" + + @R.function(pure=False, private=True) + def sugared(x: R.Tensor(["N"], "float32")): + N = T.int64() + _ = R.assert_op(N % 16 == 0) + out = R.call_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info]) + return out + + @R.function(pure=False, private=True) + def unsugared(x: R.Tensor(["N"], "float32")): + N = T.int64() + _ = R.assert_op(R.prim_value(N % 16 == 0)) + out = R.call_packed("fast_vectorized_impl", x, sinfo_args=[x.struct_info]) + return out + + tvm.ir.assert_structural_equal(unsugared, sugared) + + def test_erase_to_well_defined_removes_internal_vars(): @R.function def foo(x: R.Tensor): @@ -1664,9 +1807,9 @@ def test_context_aware_parsing(): class Module: @T.prim_func def add( - X: T.Buffer(T.int64(8), "float32"), + X: T.Buffer([T.int64(2), T.int64(4)], "float32"), Y: T.Buffer((), "float32"), - Z: T.Buffer(T.int64(8), "float32"), + Z: T.Buffer([T.int64(2), T.int64(4)], "float32"), ): T.evaluate(0) diff --git a/tests/python/relax/test_vm_codegen_tir.py b/tests/python/relax/test_vm_codegen_tir.py index 21e192955b93..9a4817f5fd8a 100644 --- a/tests/python/relax/test_vm_codegen_tir.py +++ b/tests/python/relax/test_vm_codegen_tir.py @@ -72,7 +72,7 @@ def shape_func(H: T.Buffer(T.int64(4), "int64")): H[T.int64(0)] = H[T.int64(0)] + T.int64(1) @R.function(pure=False) - def foo(x: R.Tensor): + def foo(x: R.Tensor([4], "int64")): R.func_attr({"global_symbol": "foo"}) _ = Before.shape_func(x) return x diff --git a/tests/python/tir-base/test_tir_specialize.py b/tests/python/tir-base/test_tir_specialize.py index 042288723376..cead775e97cd 100644 --- a/tests/python/tir-base/test_tir_specialize.py +++ b/tests/python/tir-base/test_tir_specialize.py @@ -330,12 +330,11 @@ def expected(A_data: T.handle("float32")): tvm.ir.assert_structural_equal(expected, after) -def test_specialization_removes_struct_info(): - """Reset struct info in specialization +def test_specialization_updates_struct_info(): + """Update struct info in specialization - While a PrimFunc usually doesn't have a `relax.StructInfo`, the - field can be populated in some edge cases. If that PrimFunc is - specialized, the struct info should be reset. + A PrimFunc may have a `relax.StructInfo`. If that PrimFunc is + specialized, the struct info should be updated. """ @T.prim_func(private=True) @@ -346,24 +345,20 @@ def before(n: T.int32) -> T.int32: def expected() -> T.int32: T.ret(50) - sinfo = tvm.relax.FuncStructInfo( + sinfo_before = tvm.relax.FuncStructInfo( [tvm.relax.PrimStructInfo("int32")], tvm.relax.PrimStructInfo("int32") ) - tvm.relax.expr._update_struct_info(before, sinfo) + tvm.ir.assert_structural_equal(before.struct_info, sinfo_before) + + sinfo_expected = tvm.relax.FuncStructInfo([], tvm.relax.PrimStructInfo("int32")) + tvm.ir.assert_structural_equal(expected.struct_info, sinfo_expected) n = before.params[0] param_map = {n: 5} after = before.specialize(param_map) - tvm.ir.assert_structural_equal(expected, after) - assert before.struct_info is not None - - # PrimFuncs do not expose the `struct_info_` field. Checking the - # `struct_info` field when it isn't set raises an exception. This - # is the desired behavior, since the struct info before - # specialization is no longer valid. - with pytest.raises(tvm.TVMError): - after.struct_info + tvm.ir.assert_structural_equal(after, expected) + tvm.ir.assert_structural_equal(after.struct_info, sinfo_expected) if __name__ == "__main__": diff --git a/tests/python/tvmscript/test_tvmscript_parser_tir.py b/tests/python/tvmscript/test_tvmscript_parser_tir.py index 074603681f34..465ffa5cb602 100644 --- a/tests/python/tvmscript/test_tvmscript_parser_tir.py +++ b/tests/python/tvmscript/test_tvmscript_parser_tir.py @@ -340,5 +340,114 @@ def func(A: T.Buffer((128, 128)), B: T.Buffer((128, 128))): assert loop_j.thread_binding.var.dtype == "int32" +def test_inferred_sinfo_with_prim_args(): + """A PrimFunc may have inferred StructInfo""" + + @T.prim_func + def func(M: T.int32, N: T.int32) -> T.int32: + T.ret(M * N) + + expected = tvm.relax.FuncStructInfo( + [ + tvm.relax.PrimStructInfo("int32"), + tvm.relax.PrimStructInfo("int32"), + ], + tvm.relax.PrimStructInfo("int32"), + purity=True, + ) + tvm.ir.assert_structural_equal(func.struct_info, expected) + + +def test_inferred_sinfo_with_buffer_args(): + """PrimFunc buffer arguments are inferred as R.Tensor""" + + @T.prim_func + def func(A: T.Buffer([16, 16], "float32"), B: T.Buffer([256], "int32")) -> T.float32: + T.ret(T.float32(42.0)) + + expected = tvm.relax.FuncStructInfo( + [ + tvm.relax.TensorStructInfo([16, 16], "float32"), + tvm.relax.TensorStructInfo([256], "int32"), + ], + tvm.relax.PrimStructInfo("float32"), + purity=True, + ) + tvm.ir.assert_structural_equal(func.struct_info, expected) + + +def test_inferred_sinfo_with_internal_allocation(): + """A pure function may still write to internal allocations. + + Whether a function writes to internal allocations is not a visible + effect, and does not impact the purity of a function. + """ + + @T.prim_func + def func(A: T.Buffer([16, 16], "float32")) -> T.float32: + Sum = T.decl_buffer([], "float32") + Sum[()] = 0.0 + for i, j in T.grid(16, 16): + Sum[()] = Sum[()] + A[i, j] + + T.ret(Sum[()]) + + expected = tvm.relax.FuncStructInfo( + [ + tvm.relax.TensorStructInfo([16, 16], "float32"), + ], + tvm.relax.PrimStructInfo("float32"), + purity=True, + ) + tvm.ir.assert_structural_equal(func.struct_info, expected) + + +def test_inferred_sinfo_with_output_buffer(): + """A pure function may not write to an argument buffer + + If an argument buffer is written to, the function must be impure. + """ + + @T.prim_func + def func(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] + + expected = tvm.relax.FuncStructInfo( + [ + tvm.relax.TensorStructInfo([16], "float32"), + tvm.relax.TensorStructInfo([16], "float32"), + ], + tvm.relax.TupleStructInfo([]), + purity=False, + ) + tvm.ir.assert_structural_equal(func.struct_info, expected) + + +def test_inferred_sinfo_with_dynamic_buffer(): + """The inferred StructInfo may contain dynamic shapes""" + + @T.prim_func + def func(a_handle: T.handle, b_handle: T.handle): + M = T.int64() + N = T.int64() + A = T.match_buffer(a_handle, [M, N], "float32") + B = T.match_buffer(b_handle, [M * N], "float32") + for i, j in T.grid(M, N): + B[i * N + j] = A[i, j] + + M = tvm.tir.Var("M", "int64") + N = tvm.tir.Var("N", "int64") + expected = tvm.relax.FuncStructInfo( + [ + tvm.relax.TensorStructInfo([M, N], "float32"), + tvm.relax.TensorStructInfo([M * N], "float32"), + ], + tvm.relax.TupleStructInfo([]), + purity=False, + ) + tvm.ir.assert_structural_equal(func.struct_info, expected) + + if __name__ == "__main__": tvm.testing.main() From 95663ab7de9da9981cb390ded1e1ce83ff0e13bd Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 26 Mar 2024 12:18:24 -0500 Subject: [PATCH 4/4] Lint fix --- tests/python/relax/test_analysis_well_formed.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py index fe171d9606c9..7deddfd28eb9 100644 --- a/tests/python/relax/test_analysis_well_formed.py +++ b/tests/python/relax/test_analysis_well_formed.py @@ -656,6 +656,7 @@ def subroutine(A: R.Tensor([16, 32], "float32"), B: R.Tensor([32, 64], "float32" assert rx.analysis.well_formed(Module["main"]) assert rx.analysis.well_formed(Module["subroutine"]) + def test_pass_dltensor_arg_to_tir(): """Relax may pass R.Tensor as DLTensor