diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index 3e3d94c614c3..97460ba4a98b 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -61,8 +61,13 @@ bool BiasAddRel(const Array& types, int num_inputs, const Attrs& attrs, if (axis < 0) { axis = data->shape.size() + axis; } - ICHECK_LE(axis, static_cast(data->shape.size())) - << "axis " << param->axis << " is out of range"; + if (axis >= static_cast(data->shape.size())) { + reporter->GetDiagCtx().EmitFatal(Diagnostic::Error(reporter->GetSpan()) + << "The axis in bias_add must be in range for the shape; " + << "attempted to access index " << axis << " of " + << PrettyPrint(data->shape)); + return false; + } // assign output type reporter->Assign(types[1], TensorType({data->shape[axis]}, data->dtype)); diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 54d04da5e092..ea5dd6948b11 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -201,6 +201,17 @@ def test_bias_add(): np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=rtol) +def test_bias_add_type_failure(): + # the axis is out of range + try: + b_add = relay.nn.bias_add(relay.const(1), relay.const(2), axis=0) + run_infer_type(b_add) + except tvm._ffi.base.TVMError: + pass + else: + assert False + + def test_expand_dims_infer_type(): for dtype in ["float16", "float32"]: n, t, d = te.size_var("n"), te.size_var("t"), 100 @@ -484,6 +495,7 @@ def test_bitserial_dense(): if __name__ == "__main__": test_concatenate() test_bias_add() + test_bias_add_type_failure() test_unary_op() test_binary_op() test_expand_dims_infer_type()