From e3b9c93c593287997c374177419d9a80692dbc17 Mon Sep 17 00:00:00 2001 From: tqchen Date: Tue, 15 Jan 2019 11:23:05 -0800 Subject: [PATCH] [OP] Fix reduce op problem when axis=None --- src/relay/op/tensor/reduce.cc | 7 ++++--- tests/python/relay/test_op_level4.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 95c26c3ab7e4..18817e8e4b6d 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -101,6 +101,7 @@ inline std::vector GetReduceAxes(const uint32_t indim, // Get axis under exclude condition. Array GetExcludeAxes(size_t indim, const Array& inaxis) { + CHECK(inaxis.defined()) << "Cannot set exclude when axis=None"; std::vector axis_flag(indim, true); for (auto i : inaxis) { int64_t axis = i->value; @@ -137,9 +138,9 @@ Array ReduceCompute(const Attrs& attrs, auto axes = param->axis; if (param->exclude) { axes = GetExcludeAxes(inputs[0]->shape.size(), param->axis); - } - if (axes.size() == 0) { - return { topi::identity(inputs[0]) }; + if (axes.size() == 0) { + return { topi::identity(inputs[0]) }; + } } return { f(inputs[0], axes, param->keepdims, false) }; } diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index db478ff251c5..45d6d36fdc20 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -135,7 +135,7 @@ def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32") out_type = "int32" if test_func in [relay.argmin, relay.argmax] else dtype assert zz.checked_type == relay.ty.TensorType(output, out_type) - if all(isinstance(v, tvm.expr.Var) == 1 for v in data) or len(output) == 0: + if all(isinstance(v, tvm.expr.Var) == 1 for v in data): return func = relay.Function([x], z) @@ -187,7 +187,7 @@ def _wrapper(data, axis=None, keepdims=False): verify_reduce(func, (2, 3, 4), 1, True, False, (2, 1, 4)) verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4)) verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ()) - verify_reduce(func, (4, 4, 3), None, False, True, ()) + verify_reduce(func, (4, 4, 3), None, False, False, ()) verify_reduce(func, (4, 4, 3), (0, 2), False, False, (4,)) verify_reduce(func, (128, 24, 128), (0, 1), False, False, (128,)) verify_reduce(func, (128, 24, 128), (0, 2), False, False, (24,))