diff --git a/include/tvm/build_module.h b/include/tvm/build_module.h index 7114a4550331..a83288ce3662 100644 --- a/include/tvm/build_module.h +++ b/include/tvm/build_module.h @@ -229,6 +229,9 @@ class BuildConfigNode : public Node { /*! \brief Whether to disable loop vectorization. */ bool disable_vectorize = false; + /*! \brief Whether to disable assert stmt generation. */ + bool disable_assert = false; + void VisitAttrs(AttrVisitor* v) { v->Visit("data_alignment", &data_alignment); v->Visit("offset_factor", &offset_factor); @@ -244,6 +247,7 @@ class BuildConfigNode : public Node { v->Visit("instrument_bound_checkers", &instrument_bound_checkers); v->Visit("disable_select_rewriting", &disable_select_rewriting); v->Visit("disable_vectorize", &disable_vectorize); + v->Visit("disable_assert", &disable_assert); } static constexpr const char* _type_key = "BuildConfig"; diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 76d7d61f1e3d..5c5c4bb2f452 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -563,6 +563,13 @@ LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target); */ LoweredFunc InferFragment(LoweredFunc f); +/*! + * \brief skip assert stmt generation + * \param f The function to be transformed. + * \return Transformed function. + */ +LoweredFunc SkipAssert(LoweredFunc f); + /*! * \brief Verify if memory accesses are legal for a specific target device type. * diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py index 217318ebfa84..f96e28323595 100644 --- a/python/tvm/build_module.py +++ b/python/tvm/build_module.py @@ -144,7 +144,8 @@ class BuildConfig(NodeBase): "dump_pass_ir": False, "instrument_bound_checkers": False, "disable_select_rewriting": False, - "disable_vectorize": False + "disable_vectorize": False, + "disable_assert": False } _dump_ir = DumpIR() diff --git a/src/codegen/build_module.cc b/src/codegen/build_module.cc index 3f279f8772df..ac991d4bfea3 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -672,6 +672,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", "; p->stream << "disable_select_rewriting=" << op->disable_select_rewriting; p->stream << "disable_vectorize=" << op->disable_vectorize; + p->stream << "disable_assert=" << op->disable_assert; p->stream << ")"; }); diff --git a/src/codegen/codegen.cc b/src/codegen/codegen.cc index ed9484b211b0..4ea37ba7317b 100644 --- a/src/codegen/codegen.cc +++ b/src/codegen/codegen.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -40,12 +41,21 @@ runtime::Module Build(const Array& funcs, if (pos != std::string::npos) { mode = mode.substr(0, pos); } + Array transformed_funcs; + for (const auto& x : funcs) { + if (BuildConfig::Current()->disable_assert) { + auto func = ir::SkipAssert(x); + transformed_funcs.push_back(func); + } + } std::string build_f_name = "codegen.build_" + mode; // the build function. const PackedFunc* bf = runtime::Registry::Get(build_f_name); CHECK(bf != nullptr) << "Target " << target << " is not enabled"; - runtime::Module m = (*bf)(funcs, target); + runtime::Module m = transformed_funcs.empty() ? + (*bf)(funcs, target) : + (*bf)(transformed_funcs, target); return m; } diff --git a/src/pass/skip_assert.cc b/src/pass/skip_assert.cc new file mode 100644 index 000000000000..5f310a61dfe3 --- /dev/null +++ b/src/pass/skip_assert.cc @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +namespace tvm { +namespace ir { + +class AssertSkipper : public IRMutator { + public: + Stmt Mutate_(const AssertStmt* op, const Stmt& s) final { + Stmt stmt = IRMutator::Mutate_(op, s); + op = stmt.as(); + return op->body; + } +}; + +Stmt SkipAssert(Stmt stmt) { + return AssertSkipper().Mutate(stmt); +} + +LoweredFunc SkipAssert(LoweredFunc f) { + auto n = make_node(*f.operator->()); + n->body = SkipAssert(f->body); + return LoweredFunc(n); +} + +} // namespace ir +} // namespace tvm