diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 5fc42392c337..fc326c18730e 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -720,6 +720,15 @@ TVM_DLL const Op& texture2d_load(); */ TVM_DLL const Op& mem_copy(); +/*! + * \brief Provide a true statement that can be used for simplifications + * + * Compile-time representation of known constraints about function + * inputs. This assumption is removed when lowering, and does not + * occur in codegen. + */ +TVM_DLL const Op& assume(); + /*! \brief The kind of structure field info used in intrinsic */ enum TVMStructFieldKind : int { // array head address diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py index 440e1ca77d36..627b89086a8c 100644 --- a/python/tvm/script/tir/intrin.py +++ b/python/tvm/script/tir/intrin.py @@ -240,6 +240,17 @@ def store(var, index, value, predicate=True, span=None): super().__init__(store, stmt=True) +@register +class AssumeIntrin(Intrin): + def __init__(self): + def assume(constraint, span): + return tvm.tir.Evaluate( + tvm.tir.call_intrin("bool", "tir.assume", constraint, span=span) + ) + + super().__init__(assume, stmt=True) + + @register def comm_reducer(lambda_io, identities, span): """Create a CommReducer from lambda inputs/outputs and the identities""" diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 6cc7b2e1f885..d63c65dfddde 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -253,6 +253,17 @@ def RemoveNoOp(): return _ffi_api.RemoveNoOp() # type: ignore +def RemoveAssume(): + """Remove all instances of builtin::assume + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.RemoveAssume() # type: ignore + + def BF16Legalize(): """Legalize bf16 typed Ops. Runs BF16Promote, BF16CastElimination and BF16TypeLowering diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index aaebc7409f29..f2abf5c78d73 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1181,6 +1181,14 @@ Doc TVMScriptPrinter::VisitStmt_(const SeqStmtNode* op) { } Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) { + if (auto* call = op->value.as()) { + if (call->op.same_as(builtin::assume())) { + Doc doc; + doc << tir_prefix_ << ".assume(" << Print(call->args[0]) << ")"; + return doc; + } + } + Doc doc; doc << tir_prefix_ << ".evaluate(" << Print(op->value) << ")"; return doc; diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 1871a3d7bf70..860f98dd1430 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -288,6 +288,10 @@ TIR_DEFINE_BUILTIN_FUNC(texture2d_load) TIR_DEFINE_BUILTIN_FUNC(mem_copy).set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(assume) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kEmbedInfo)) + .set_num_inputs(1); + } // namespace builtin } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/remove_assume.cc b/src/tir/transforms/remove_assume.cc new file mode 100644 index 000000000000..928bcf02bc1b --- /dev/null +++ b/src/tir/transforms/remove_assume.cc @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file remove_store_undef.cc + * \brief Remove stores of tir::builtin::undef + */ +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { + +// Remove any builtin::assume calls +class AssumeRemover : public StmtExprMutator { + public: + using Parent = StmtExprMutator; + + Stmt VisitStmt_(const EvaluateNode* op) final { + if (auto* call = op->value.as()) { + if (call->op.same_as(builtin::assume())) { + return Evaluate(0); + } + } + return StmtExprMutator::VisitStmt_(op); + } +}; + +namespace transform { +Pass RemoveAssumeInternal() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = AssumeRemover()(std::move(n->body)); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.RemoveAssumeInternal", {}); +} + +Pass RemoveAssume() { + return Sequential({RemoveAssumeInternal(), RemoveNoOp()}, "tir.RemoveAssume"); +} + +TVM_REGISTER_GLOBAL("tir.transform.RemoveAssume").set_body_typed(RemoveAssume); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_transform_remove_assume.py b/tests/python/unittest/test_tir_transform_remove_assume.py new file mode 100644 index 000000000000..4223e40e3f2a --- /dev/null +++ b/tests/python/unittest/test_tir_transform_remove_assume.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm.script import tir as T +from tvm import TVMError + + +class BaseBeforeAfter(tvm.testing.CompareBeforeAfter): + @tvm.testing.fixture + def transform(self): + return tvm.tir.transform.RemoveAssume() + + +class TestRemoveAssume(BaseBeforeAfter): + """Remove any instance of T.assume""" + + def before(A: T.Buffer[1, "int32"]): + T.assume(A[0] == 5) + A[0] = 10 + + def expected(A: T.Buffer[1, "int32"]): + A[0] = 10 + + +class TestRemoveAssumeLoop(BaseBeforeAfter): + """Loops containing only T.assume should be removed""" + + def before(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + T.assume(A[i] == 0) + + for i in T.serial(16): + A[i] = 10 + + def expected(A: T.Buffer[16, "int32"]): + for i in T.serial(16): + A[i] = 10 + + +if __name__ == "__main__": + tvm.testing.main()