Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,17 @@ TVM_DLL Pass VerifyMemory();
*/
TVM_DLL Pass VerifyGPUCode(Map<String, PrimExpr> constraints);

/*!
* \brief Statically check TIR code for out of bounds array access.
*
* This analysis is conservative: it will only raise errors if it can prove
* that out of bounds access occurs. Cases that are uncertain do not raise
* errors.
*
* \returns The pass.
*/
TVM_DLL Pass OOBChecker();

} // namespace transform
} // namespace tir
} // namespace tvm
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/tir/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,14 @@ def verify_well_formed(func: PrimFunc, assert_mode: bool = True) -> bool:
Whether it is a well-formed TIR function.
"""
return _ffi_api.VerifyWellFormed(func, assert_mode) # type: ignore # pylint: disable=no-member


def OOBChecker():
"""Detect out of bounds memory access in arrays.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.OOBChecker() # type: ignore
4 changes: 2 additions & 2 deletions src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -436,11 +436,11 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
int64_t vstride = stride.Eval()->value;
if (vstride > 0) {
return Combine<Add>(analyzer_, base,
IntervalSet(make_zero(t), make_const(t, vstride * op->lanes - 1)),
IntervalSet(make_zero(t), make_const(t, vstride * (op->lanes - 1))),
op->dtype);
} else {
return Combine<Add>(analyzer_, base,
IntervalSet(make_const(t, vstride * op->lanes + 1), make_zero(t)),
IntervalSet(make_const(t, vstride * (op->lanes - 1)), make_zero(t)),
op->dtype);
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ Pass GetPass(const String& pass_name) {
// pass
} else if ((f = Registry::Get("relay._transform." + pass_name))) {
}
ICHECK(f != nullptr) << "Cannot use " << pass_name << "to create the pass";
ICHECK(f != nullptr) << "Cannot use " << pass_name << " to create the pass";
return (*f)();
}

Expand Down
130 changes: 130 additions & 0 deletions src/tir/analysis/oob_checker.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Out of bounds array access static analyzer.
*/

#include <tvm/tir/transform.h>

#include "../../arith/ir_visitor_with_analyzer.h"
#include "../../printer/text_printer.h"
#include "../schedule/error.h"

namespace tvm {
namespace tir {
namespace transform {
struct OOBLocation {
Buffer buf;
size_t dimension;
ObjectRef index;
arith::IntSet index_bounds;
arith::IntSet shape_bounds;
};

class OOBError : public ScheduleError {
public:
OOBError(IRModule mod, std::vector<OOBLocation> locations) : mod_(mod), locations_(locations) {}
String FastErrorString() const final { return "Out of bound memory access"; }

String DetailRenderTemplate() const final {
std::stringstream s;
for (const auto& oob : locations_) {
s << "Out of bounds memory access on buffer " << oob.buf->name << " dimension "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: It might be nice to have a test where you show the full rendered strings.

<< oob.dimension << ".";
s << " index " << oob.index << " with bounds [" << oob.index_bounds.min() << ", "
<< oob.index_bounds.max() << "] is outside the range [0, " << oob.shape_bounds.min()
<< "].";
s << "\n";
}
return s.str();
}
IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final {
std::vector<ObjectRef> locs;
for (auto loc : locations_) {
locs.push_back(loc.index);
}
return locs;
}

private:
IRModule mod_;
std::vector<OOBLocation> locations_;
};
class OOBCheckerVisitor final : public arith::IRVisitorWithAnalyzer {
using IRVisitorWithAnalyzer::VisitExpr_;
using IRVisitorWithAnalyzer::VisitStmt_;

public:
void VisitStmt_(const BufferStoreNode* node) final {
for (size_t i = 0; i < node->buffer->shape.size(); i++) {
CheckBounds(node, i);
}
IRVisitorWithAnalyzer::VisitStmt_(node);
}
void VisitExpr_(const BufferLoadNode* node) final {
for (size_t i = 0; i < node->buffer->shape.size(); i++) {
CheckBounds(node, i);
}
IRVisitorWithAnalyzer::VisitExpr_(node);
}

template <class T>
void CheckBounds(const T* node, size_t i) {
auto ind_bounds = analyzer_.int_set(node->indices[i]);
auto shape_bounds = analyzer_.int_set(node->buffer->shape[i]);
// We would expect that
// `analyzer_.CanProve(node->indices[i] < 0 || node->indices[i] >= node->buffer->shape[i])`
// would be the way to check if any out of bounds access occurs here, but `CanProve` checks if
// the statement is true for all possible values (universal quantification). For a mix of in
// bounds and out of bounds access, no out of bounds access would be reported. We instead want
// to check if there is any value for which the access is out of bounds (existential
// quantification).
// An solution would be to check that the index is in bounds for every possible value. This
// has the problem that some valid access patterns maybe be valid but not provably valid. We
// prefer that this analysis is conservative and only shows errors that are provable. This leads
// us to the following check: are the bounds of the index outside the bounds of the shape.
if (analyzer_.CanProve(ind_bounds.max() >= shape_bounds.min()) ||
analyzer_.CanProve(ind_bounds.min() < 0)) {
errors.push_back({node->buffer, i, node->indices[i], ind_bounds, shape_bounds});
}
}

std::vector<OOBLocation> errors;
};

transform::Pass OOBChecker() {
auto pass_func = [=](tir::PrimFunc func, IRModule mod, transform::PassContext ctx) {
OOBCheckerVisitor checker;
checker(func->body);
if (checker.errors.size() > 0) {
// mod doesn't contain our function, so we construct a new mod with out function
IRModule func_mod({{GlobalVar("main"), func}});
LOG(FATAL) << OOBError(func_mod, checker.errors).RenderReport("Out of bounds checker");
}
return func;
};
return transform::CreatePrimFuncPass(pass_func, 0, "tir.analysis.OOBChecker", {});
}

TVM_REGISTER_GLOBAL("tir.analysis.OOBChecker").set_body_typed(OOBChecker);
} // namespace transform
} // namespace tir
} // namespace tvm
4 changes: 3 additions & 1 deletion src/tir/schedule/error.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ String ScheduleError::RenderReport(const String& primitive) const {
runtime::TypedPackedFunc<std::string(Stmt)>(
[&loc_obj_to_name](const Stmt& expr) -> std::string {
auto it = loc_obj_to_name.find(Downcast<ObjectRef>(expr));
if (it == loc_obj_to_name.end()) return "";
if (it == loc_obj_to_name.end()) {
return "";
}
return it->second;
});

Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_arith_intset.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_vector():
lanes = 2
s = tvm.arith.IntSet.vector(tvm.tir.Ramp(base, stride, lanes))
assert s.min_value.value == base
assert s.max_value.value == base + stride * lanes - 1
assert s.max_value.value == base + stride * (lanes - 1)


def test_add_sub():
Expand Down
78 changes: 78 additions & 0 deletions tests/python/unittest/test_tir_analysis_oob.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest

import tvm
from tvm.script import tir as T


@T.prim_func
def bad_load(A: T.Buffer[(2, 3), "float32"], B: T.Buffer[(3, 2), "float32"]):
B[0, 0] = A[2, 2]


@T.prim_func
def bad_load_loop(A: T.Buffer[(2, 3), "float32"], B: T.Buffer[(3, 2), "float32"]):
for i in range(3):
B[i, 0] = A[i, 2]


@T.prim_func
def bad_store(A: T.Buffer[(2, 3), "float32"], B: T.Buffer[(3, 2), "float32"]):
B[0, 3] = A[1, 2]


@T.prim_func
def bad_store_loop(A: T.Buffer[(2, 3), "float32"], B: T.Buffer[(3, 2), "float32"]):
for i in range(3):
B[0, i] = A[1, i]


@T.prim_func
def unknown_bounds(A: T.Buffer[(2, 3), "float32"], B: T.Buffer[(3, 2), "float32"]):
N = T.var("int32")
for i in range(3):
B[0, N] = A[1, i]


def test_oob_load():
with pytest.raises(tvm.tir.ScheduleError) as err:
tvm.tir.analysis.OOBChecker()(tvm.IRModule.from_expr(bad_load))
assert "buffer A" in err.value.args[0]

with pytest.raises(tvm.tir.ScheduleError) as err:
tvm.tir.analysis.OOBChecker()(tvm.IRModule.from_expr(bad_load_loop))
assert "buffer A" in err.value.args[0]


def test_oob_store():
with pytest.raises(tvm.tir.ScheduleError) as err:
tvm.tir.analysis.OOBChecker()(tvm.IRModule.from_expr(bad_store))
assert "buffer B" in err.value.args[0]

with pytest.raises(tvm.tir.ScheduleError) as err:
tvm.tir.analysis.OOBChecker()(tvm.IRModule.from_expr(bad_store_loop))
assert "buffer B" in err.value.args[0]


def test_unknown_bounds():
# This should not return an error as we can't probe that N goes out of bounds
tvm.tir.analysis.OOBChecker()(tvm.IRModule.from_expr(unknown_bounds))


if __name__ == "__main__":
tvm.testing.main()