From ad15fb1d00fb473fa2adce6f247f867bfedfc875 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Sun, 28 Feb 2021 23:14:24 -0800 Subject: [PATCH 1/2] Check for being out of range on the negative side too --- src/relay/op/nn/nn.cc | 4 ++-- tests/python/relay/test_op_level1.py | 17 +++++++++-------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 97460ba4a98b..18e9c3235295 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -61,10 +61,10 @@ bool BiasAddRel(const Array& types, int num_inputs, const Attrs& attrs, if (axis < 0) { axis = data->shape.size() + axis; } - if (axis >= static_cast(data->shape.size())) { + if (axis >= static_cast(data->shape.size()) || axis < 0) { reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) << "The axis in bias_add must be in range for the shape; " - << "attempted to access index " << axis << " of " + << "attempted to access index " << param->axis << " of " << PrettyPrint(data->shape)); return false; } diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index ea5dd6948b11..d9ae3b23b81a 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -202,14 +202,15 @@ def test_bias_add(): def test_bias_add_type_failure(): - # the axis is out of range - try: - b_add = relay.nn.bias_add(relay.const(1), relay.const(2), axis=0) - run_infer_type(b_add) - except tvm._ffi.base.TVMError: - pass - else: - assert False + def assert_failure(expr): + try: + run_infer_type(expr) + except tvm._ffi.base.TVMError: + return + else: + assert False + for axis in (0, -1, -3, 1): + assert_failure(relay.nn.bias_add(relay.const(1), relay.const(2), axis=axis)) def test_expand_dims_infer_type(): From 5a6c25343058922dd42fe7aac105bd33e6ce9fd8 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Sun, 28 Feb 2021 23:21:29 -0800 Subject: [PATCH 2/2] Linting fix --- tests/python/relay/test_op_level1.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index d9ae3b23b81a..dfd350486c3b 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -209,6 +209,7 @@ def assert_failure(expr): return else: assert False + for axis in (0, -1, -3, 1): assert_failure(relay.nn.bias_add(relay.const(1), relay.const(2), axis=axis))