diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index d60a222ac265..402fa5515431 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -262,6 +262,17 @@ TVM_DLL Pass VerifyMemory(); */ TVM_DLL Pass VerifyGPUCode(Map constraints); +/*! + * \brief Statically check TIR code for out of bounds array access. + * + * This analysis is conservative: it will only raise errors if it can prove + * that out of bounds access occurs. Cases that are uncertain do not raise + * errors. + * + * \returns The pass. + */ +TVM_DLL Pass OOBChecker(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 13674daa2413..545404171309 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -320,3 +320,14 @@ def verify_well_formed(func: PrimFunc, assert_mode: bool = True) -> bool: Whether it is a well-formed TIR function. """ return _ffi_api.VerifyWellFormed(func, assert_mode) # type: ignore # pylint: disable=no-member + + +def OOBChecker(): + """Detect out of bounds memory access in arrays. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.OOBChecker() # type: ignore diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 6d48ad1ed151..584bbe8f04ea 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -436,11 +436,11 @@ class IntervalSetEvaluator : public ExprFunctor { int64_t vstride = stride.Eval()->value; if (vstride > 0) { return Combine(analyzer_, base, - IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1)), + IntervalSet(make_zero(t), make_const(t, vstride * (op->lanes - 1))), op->dtype); } else { return Combine(analyzer_, base, - IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t)), + IntervalSet(make_const(t, vstride * (op->lanes - 1)), make_zero(t)), op->dtype); } } diff --git a/src/ir/transform.cc b/src/ir/transform.cc index d945278abc72..77ea942a0bb9 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -431,7 +431,7 @@ Pass GetPass(const String& pass_name) { // pass } else if ((f = Registry::Get("relay._transform." + pass_name))) { } - ICHECK(f != nullptr) << "Cannot use " << pass_name << "to create the pass"; + ICHECK(f != nullptr) << "Cannot use " << pass_name << " to create the pass"; return (*f)(); } diff --git a/src/tir/analysis/oob_checker.cc b/src/tir/analysis/oob_checker.cc new file mode 100644 index 000000000000..a3d3501a9aae --- /dev/null +++ b/src/tir/analysis/oob_checker.cc @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Out of bounds array access static analyzer. + */ + +#include + +#include "../../arith/ir_visitor_with_analyzer.h" +#include "../../printer/text_printer.h" +#include "../schedule/error.h" + +namespace tvm { +namespace tir { +namespace transform { +struct OOBLocation { + Buffer buf; + size_t dimension; + ObjectRef index; + arith::IntSet index_bounds; + arith::IntSet shape_bounds; +}; + +class OOBError : public ScheduleError { + public: + OOBError(IRModule mod, std::vector locations) : mod_(mod), locations_(locations) {} + String FastErrorString() const final { return "Out of bound memory access"; } + + String DetailRenderTemplate() const final { + std::stringstream s; + for (const auto& oob : locations_) { + s << "Out of bounds memory access on buffer " << oob.buf->name << " dimension " + << oob.dimension << "."; + s << " index " << oob.index << " with bounds [" << oob.index_bounds.min() << ", " + << oob.index_bounds.max() << "] is outside the range [0, " << oob.shape_bounds.min() + << "]."; + s << "\n"; + } + return s.str(); + } + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { + std::vector locs; + for (auto loc : locations_) { + locs.push_back(loc.index); + } + return locs; + } + + private: + IRModule mod_; + std::vector locations_; +}; +class OOBCheckerVisitor final : public arith::IRVisitorWithAnalyzer { + using IRVisitorWithAnalyzer::VisitExpr_; + using IRVisitorWithAnalyzer::VisitStmt_; + + public: + void VisitStmt_(const BufferStoreNode* node) final { + for (size_t i = 0; i < node->buffer->shape.size(); i++) { + CheckBounds(node, i); + } + IRVisitorWithAnalyzer::VisitStmt_(node); + } + void VisitExpr_(const BufferLoadNode* node) final { + for (size_t i = 0; i < node->buffer->shape.size(); i++) { + CheckBounds(node, i); + } + IRVisitorWithAnalyzer::VisitExpr_(node); + } + + template + void CheckBounds(const T* node, size_t i) { + auto ind_bounds = analyzer_.int_set(node->indices[i]); + auto shape_bounds = analyzer_.int_set(node->buffer->shape[i]); + // We would expect that + // `analyzer_.CanProve(node->indices[i] < 0 || node->indices[i] >= node->buffer->shape[i])` + // would be the way to check if any out of bounds access occurs here, but `CanProve` checks if + // the statement is true for all possible values (universal quantification). For a mix of in + // bounds and out of bounds access, no out of bounds access would be reported. We instead want + // to check if there is any value for which the access is out of bounds (existential + // quantification). + // An solution would be to check that the index is in bounds for every possible value. This + // has the problem that some valid access patterns maybe be valid but not provably valid. We + // prefer that this analysis is conservative and only shows errors that are provable. This leads + // us to the following check: are the bounds of the index outside the bounds of the shape. + if (analyzer_.CanProve(ind_bounds.max() >= shape_bounds.min()) || + analyzer_.CanProve(ind_bounds.min() < 0)) { + errors.push_back({node->buffer, i, node->indices[i], ind_bounds, shape_bounds}); + } + } + + std::vector errors; +}; + +transform::Pass OOBChecker() { + auto pass_func = [=](tir::PrimFunc func, IRModule mod, transform::PassContext ctx) { + OOBCheckerVisitor checker; + checker(func->body); + if (checker.errors.size() > 0) { + // mod doesn't contain our function, so we construct a new mod with out function + IRModule func_mod({{GlobalVar("main"), func}}); + LOG(FATAL) << OOBError(func_mod, checker.errors).RenderReport("Out of bounds checker"); + } + return func; + }; + return transform::CreatePrimFuncPass(pass_func, 0, "tir.analysis.OOBChecker", {}); +} + +TVM_REGISTER_GLOBAL("tir.analysis.OOBChecker").set_body_typed(OOBChecker); +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc index 4ce5a97bb5d3..32e5c2455a85 100644 --- a/src/tir/schedule/error.cc +++ b/src/tir/schedule/error.cc @@ -46,7 +46,9 @@ String ScheduleError::RenderReport(const String& primitive) const { runtime::TypedPackedFunc( [&loc_obj_to_name](const Stmt& expr) -> std::string { auto it = loc_obj_to_name.find(Downcast(expr)); - if (it == loc_obj_to_name.end()) return ""; + if (it == loc_obj_to_name.end()) { + return ""; + } return it->second; }); diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index 74b53442ec27..2302d0ed54f2 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -50,7 +50,7 @@ def test_vector(): lanes = 2 s = tvm.arith.IntSet.vector(tvm.tir.Ramp(base, stride, lanes)) assert s.min_value.value == base - assert s.max_value.value == base + stride * lanes - 1 + assert s.max_value.value == base + stride * (lanes - 1) def test_add_sub(): diff --git a/tests/python/unittest/test_tir_analysis_oob.py b/tests/python/unittest/test_tir_analysis_oob.py new file mode 100644 index 000000000000..f910ca503be2 --- /dev/null +++ b/tests/python/unittest/test_tir_analysis_oob.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm.script import tir as T + + +@T.prim_func +def bad_load(A: T.Buffer[(2, 3), "float32"], B: T.Buffer[(3, 2), "float32"]): + B[0, 0] = A[2, 2] + + +@T.prim_func +def bad_load_loop(A: T.Buffer[(2, 3), "float32"], B: T.Buffer[(3, 2), "float32"]): + for i in range(3): + B[i, 0] = A[i, 2] + + +@T.prim_func +def bad_store(A: T.Buffer[(2, 3), "float32"], B: T.Buffer[(3, 2), "float32"]): + B[0, 3] = A[1, 2] + + +@T.prim_func +def bad_store_loop(A: T.Buffer[(2, 3), "float32"], B: T.Buffer[(3, 2), "float32"]): + for i in range(3): + B[0, i] = A[1, i] + + +@T.prim_func +def unknown_bounds(A: T.Buffer[(2, 3), "float32"], B: T.Buffer[(3, 2), "float32"]): + N = T.var("int32") + for i in range(3): + B[0, N] = A[1, i] + + +def test_oob_load(): + with pytest.raises(tvm.tir.ScheduleError) as err: + tvm.tir.analysis.OOBChecker()(tvm.IRModule.from_expr(bad_load)) + assert "buffer A" in err.value.args[0] + + with pytest.raises(tvm.tir.ScheduleError) as err: + tvm.tir.analysis.OOBChecker()(tvm.IRModule.from_expr(bad_load_loop)) + assert "buffer A" in err.value.args[0] + + +def test_oob_store(): + with pytest.raises(tvm.tir.ScheduleError) as err: + tvm.tir.analysis.OOBChecker()(tvm.IRModule.from_expr(bad_store)) + assert "buffer B" in err.value.args[0] + + with pytest.raises(tvm.tir.ScheduleError) as err: + tvm.tir.analysis.OOBChecker()(tvm.IRModule.from_expr(bad_store_loop)) + assert "buffer B" in err.value.args[0] + + +def test_unknown_bounds(): + # This should not return an error as we can't probe that N goes out of bounds + tvm.tir.analysis.OOBChecker()(tvm.IRModule.from_expr(unknown_bounds)) + + +if __name__ == "__main__": + tvm.testing.main()