From 1a5cb724f1643164cb4ea63efcd7a2bd06ff085b Mon Sep 17 00:00:00 2001 From: Jon Date: Fri, 25 Oct 2019 10:32:31 -0700 Subject: [PATCH 1/5] Add support for Any op --- docs/frontend/tensorflow.rst | 1 + include/tvm/expr_operator.h | 7 +++ python/tvm/relay/frontend/tensorflow.py | 1 + python/tvm/relay/op/_reduce.py | 1 + python/tvm/relay/op/reduce.py | 52 +++++++++++++++++++ src/lang/expr_operator.cc | 10 ++++ src/relay/op/tensor/reduce.cc | 37 +++++++++++++ .../frontend/tensorflow/test_forward.py | 12 ++++- tests/python/relay/test_op_level4.py | 3 +- topi/include/topi/reduction.h | 21 ++++++++ topi/python/topi/reduction.py | 25 +++++++++ topi/src/topi.cc | 5 ++ topi/tests/python/test_topi_reduce.py | 24 +++++++++ 13 files changed, 196 insertions(+), 3 deletions(-) diff --git a/docs/frontend/tensorflow.rst b/docs/frontend/tensorflow.rst index 827f5d637988..878288840354 100644 --- a/docs/frontend/tensorflow.rst +++ b/docs/frontend/tensorflow.rst @@ -116,6 +116,7 @@ Supported Ops - Abs - Add - All +- Any - ArgMax - ArgMin - AvgPool diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index 007ae58ad4ba..a7164682eb1e 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -519,6 +519,13 @@ TVM_DLL Expr sum(Expr source, Array axis); */ TVM_DLL Expr all(Expr source, Array axis); +/*! + * \brief logical Or of of source expression over axis + * \param source The source expression. + * \param axis List of iteration variables that will be used for reduction. + */ +TVM_DLL Expr any(Expr source, Array axis); + /*! * \brief max of of source expression over axis * \param source The source expression. diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index bfa3431ba29e..1e73485794cc 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1342,6 +1342,7 @@ def _impl(inputs, attr, params): 'Abs' : AttrCvt('abs'), 'Add' : _elemwise('add'), 'All' : _reduce('all'), + 'Any' : _reduce('any'), 'ArgMax' : _argx(_op.argmax, 'argmax'), 'ArgMin' : _argx(_op.argmin, 'argmin'), 'Assert' : _assert(), diff --git a/python/tvm/relay/op/_reduce.py b/python/tvm/relay/op/_reduce.py index 845ec4b9ba87..06d0d66bdfb0 100644 --- a/python/tvm/relay/op/_reduce.py +++ b/python/tvm/relay/op/_reduce.py @@ -31,6 +31,7 @@ def _schedule_reduce(_, outs, target): _reg.register_schedule("argmin", _schedule_reduce) _reg.register_schedule("sum", _schedule_reduce) _reg.register_schedule("all", _schedule_reduce) +_reg.register_schedule("any", _schedule_reduce) _reg.register_schedule("max", _schedule_reduce) _reg.register_schedule("min", _schedule_reduce) _reg.register_schedule("prod", _schedule_reduce) diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py index 49193fd4b5c6..baf896e6bc9a 100644 --- a/python/tvm/relay/op/reduce.py +++ b/python/tvm/relay/op/reduce.py @@ -166,6 +166,58 @@ def all(data, axis=None, keepdims=False, exclude=False): return _make.all(data, axis, keepdims, exclude) +def any(data, axis=None, keepdims=False, exclude=False): + """Computes the logical OR of boolean array elements over given axes. + + Parameters + ---------- + data : relay.Expr + The input boolean tensor + + axis : None or int or tuple of int + Axis or axes along which a sum is performed. The default, axis=None, + will sum all of the elements of the input array. If axis is + negative it counts from the last to the first axis. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as + dimensions with size one. With this option, the result will broadcast + correctly against the input array. + + exclude : bool + If `exclude` is true, reduction will be performed on the axes that are + NOT in axis instead. + + Returns + ------- + result : relay.Expr + The computed result. + + Examples + -------- + .. code-block:: python + + data = relay.Constant(tvm.nd.array([[[ True, True, True], + [ True, True, True], + [False, True, False]], + [[ True, False, False], + [ True, True, False], + [False, True, True]]])) + + relay.any(data, axis=1) + # [[True, True, True], + # [True, True, True]] + + relay.any(data, axis=0) + # [[ True, True, True], + # [ True, True, True], + # [False, True, True]] + + """ + axis = [axis] if isinstance(axis, int) else axis + return _make.any(data, axis, keepdims, exclude) + + def max(data, axis=None, keepdims=False, exclude=False): """ Computes the max of array elements over given axes. diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index 9c9100b1902e..220d4378cc97 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -486,6 +486,16 @@ Expr all(Expr source, Array rdom) { return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0); } +Expr any(Expr source, Array rdom) { + CHECK(source.type().is_bool()); + Var x("x", source.type()), y("y", source.type()); + Expr result = ir::Or::make(x, y); + Expr identity_element = make_const(source.type(), false); + ir::CommReducer combiner = + ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); + return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0); +} + Expr max(Expr source, Array rdom) { Var x("x", source.type()), y("y", source.type()); Expr result = ir::Max::make(x, y); diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 51714bd9f756..63524bc4e81d 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -420,6 +420,43 @@ Example:: .set_attr("TOpPattern", kCommReduce); +Array AnyCompute(const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + return ReduceCompute(attrs, inputs, out_type, target, topi::any); +} + + +RELAY_REGISTER_REDUCE_OP("any") +.describe(R"code(Computes the logical OR of boolean array elements over given axes. + +Example:: + + data = [[[ True, True, True], + [ True, True, True], + [False, True, False]], + [[ True, False, False], + [ True, True, False], + [False, True, True]]] + + any(data, axis=1) + [[True, True, True], + [True, True, True]] + + any(data, axis=0) + [[ True, True, True], + [ True, True, True], + [False, True, True]] + +)code" TVM_ADD_FILELINE) +.set_attrs_type() +.set_support_level(4) +.add_type_rel("Reduce", ReduceRel) +.set_attr("FTVMCompute", AnyCompute) +.set_attr("TOpPattern", kCommReduce); + + Array MaxCompute(const Attrs& attrs, const Array& inputs, const Type& out_type, diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 11c6a7befca6..88787efdba16 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -2198,7 +2198,7 @@ def check_size(ishape): check_size((10,)) ####################################################################### -# All, Max, Min +# All, Any, Max, Min # ------------- def test_forward_reduce_all(): """Test the All operator.""" @@ -2208,6 +2208,14 @@ def test_forward_reduce_all(): tf.reduce_all(in_data, name="all") compare_tf_with_tvm([np_data], ['in_data:0'], 'all:0') +def test_forward_reduce_any(): + """Test the Any operator.""" + np_data = np.random.choice([True, False], size=(5, 7, 11)) + tf.reset_default_graph() + in_data = tf.placeholder(tf.bool, (5, 7, 11), name="in_data") + tf.reduce_any(in_data, name="any") + compare_tf_with_tvm([np_data], ['in_data:0'], 'any:0') + def test_forward_reduce_max(): def check_max(ishape, axis, keepdims, dtype): tf.reset_default_graph() @@ -2432,7 +2440,7 @@ def test_forward_one_hot(): test_forward_mean() test_forward_reduce_prod() test_forward_reduce_all() - test_forward_reduce_max() + test_forward_reduce_any() test_forward_reduce_min() # General diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index c34dddfd0fd7..6a8a678bfda3 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -145,7 +145,7 @@ def test_where(): def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"): test_func = funcs[0] ref_func = funcs[1] - dtype = "bool" if ref_func in [np.all] else dtype + dtype = "bool" if ref_func in [np.all, np.any] else dtype x = relay.var("x", relay.TensorType(data, dtype)) z = test_func(x, axis, keepdims, exclude) @@ -207,6 +207,7 @@ def _wrapper(data, axis=None, keepdims=False): [relay.std, np.std], [relay.prod, np.prod], [relay.all, np.all], + [relay.any, np.any], [relay.argmin, _with_keepdims(np.argmin)], [relay.argmax, _with_keepdims(np.argmax)]]: verify_reduce(func, (d1, d2, d3, d4), None, False, False, ()) diff --git a/topi/include/topi/reduction.h b/topi/include/topi/reduction.h index 14dec7765151..b7036770aa4a 100644 --- a/topi/include/topi/reduction.h +++ b/topi/include/topi/reduction.h @@ -390,6 +390,27 @@ inline Tensor all(const Tensor& data, return CommReduce(data, axis, tvm::all, keepdims, atleast1d); } +/*! +* \brief Creates an operation that computes the logical OR of elements +* over a given axis +* +* \param data The input boolean tensor +* \param axis The axes to reduce. If axis is empty, the operation will +* perform logical OR over all elements of the array. +* \param keepdims If this is set to true, the axes which are reduced are +* left in the result as dimensions with size one. This enables the result +* to broadcast correctly against the input array. +* \param atleast1d Whether the output need to be atleast1d. +* +* \return A Tensor whose op member is the all operation +*/ +inline Tensor any(const Tensor& data, + const Array& axis, + bool keepdims = false, + bool atleast1d = false) { + return CommReduce(data, axis, tvm::any, keepdims, atleast1d); +} + /*! * \brief Creates an operation that finds the minimum of elements over * a given axis. diff --git a/topi/python/topi/reduction.py b/topi/python/topi/reduction.py index 5079bf474deb..7c4e059d8334 100644 --- a/topi/python/topi/reduction.py +++ b/topi/python/topi/reduction.py @@ -90,6 +90,31 @@ def all(data, axis=None, keepdims=False): return cpp.all(data, axis, keepdims) +def any(data, axis=None, keepdims=False): + """Logical OR of array elements over a given axis or a list of axes + + Parameters + ---------- + data : tvm.Tensor + The input tvm boolean tensor + + axis : None or int or tuple of int + Axis or axes along which a logical OR is performed. + The default, axis=None, will perform logical OR over all elements of the input array. + If axis is negative it counts from the last to the first axis. + + keepdims : bool + If this is set to True, the axes which are reduced are left in the result as dimensions + with size one. + With this option, the result will broadcast correctly against the input array. + + Returns + ------- + ret : tvm.Tensor + """ + return cpp.any(data, axis, keepdims) + + def max(data, axis=None, keepdims=False): """Maximum of array elements over a given axis or a list of axes diff --git a/topi/src/topi.cc b/topi/src/topi.cc index a0700bffa7e3..01fc5983617c 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -300,6 +300,11 @@ TVM_REGISTER_GLOBAL("topi.all") *rv = topi::all(args[0], ArrayOrInt(args[1]), args[2]); }); +TVM_REGISTER_GLOBAL("topi.any") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = topi::any(args[0], ArrayOrInt(args[1]), args[2]); + }); + /* Ops from transform.h */ TVM_REGISTER_GLOBAL("topi.expand_dims") .set_body([](TVMArgs args, TVMRetValue *rv) { diff --git a/topi/tests/python/test_topi_reduce.py b/topi/tests/python/test_topi_reduce.py index 6e6470dad588..d266cfc6ceb5 100644 --- a/topi/tests/python/test_topi_reduce.py +++ b/topi/tests/python/test_topi_reduce.py @@ -52,6 +52,8 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32") B = topi.sum(A1, axis=axis, keepdims=keepdims) elif type == "all": B = topi.all(A, axis=axis, keepdims=keepdims) + elif type == "any": + B = topi.any(A, axis=axis, keepdims=keepdims) elif type == "max": B = topi.max(A1, axis=axis, keepdims=keepdims) elif type == "min": @@ -86,6 +88,8 @@ def check_device(device): out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims) elif type == "all" and dtype == 'bool': out_npy = in_npy_map.all(axis=axis, keepdims=keepdims) + elif type == "any" and dtype == "bool": + out_npy = in_npy_map.any(axis=axis, keepdims=keepdims) elif type == "max": out_npy = in_npy_map.max(axis=axis, keepdims=keepdims) elif type == "min": @@ -173,6 +177,26 @@ def test_reduce_map(): keepdims=True, type="sum", dtype="float64") + verify_reduce_map_ele(in_shape=(2, 3), + axis=None, + keepdims=True, + type="any", + dtype="bool") + verify_reduce_map_ele(in_shape=(32, 128, 24), + axis=None, + keepdims=True, + type="any", + dtype="bool") + verify_reduce_map_ele(in_shape=(1, 4, 7), + axis=1, + keepdims=True, + type="any", + dtype="bool") + verify_reduce_map_ele(in_shape=(128, 24, 128, 24), + axis=2, + keepdims=False, + type="any", + dtype="bool") if __name__ == "__main__": test_reduce_map() From 0ce672daa29d02750e9a70792f849485f6b3bbd8 Mon Sep 17 00:00:00 2001 From: Jon Date: Fri, 25 Oct 2019 10:55:35 -0700 Subject: [PATCH 2/5] Support ONNX frontend --- python/tvm/relay/frontend/onnx.py | 10 ++++- tests/python/frontend/onnx/test_forward.py | 48 ++++++++++++++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index a7f787484b2c..72eb5a4406d1 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -922,6 +922,13 @@ class Erf(OnnxOpConverter): def _impl_v1(cls, inputs, attr, params): return _op.erf(inputs[0]) +class Or(Elemwise): + """ Operator converter for Or. + """ + @classmethod + def _impl_v7(cls, inputs, attr, params): + return _op.logical_or(inputs[0], inputs[1]) + # compatible operators that do NOT require any conversion. _identity_list = [] @@ -1042,7 +1049,8 @@ def _get_convert_map(opset): 'Not': Not.get_converter(opset), 'And': And.get_converter(opset), 'Tile': Tile.get_converter(opset), - 'Erf': Erf.get_converter(opset) + 'Erf': Erf.get_converter(opset), + 'Or': Or.get_converter(opset) } diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 16e717401174..84164fb5f7c0 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1300,6 +1300,53 @@ def test_erf(): verify_erf(x, z) +def verify_or(indata, dtype): + x = indata[0].astype(dtype) + y = indata[1].astype(dtype) + outdata = np.logical_or(x, y) + + node = helper.make_node('Or', inputs=['in1', 'in2'], outputs=['out'], ) + + graph = helper.make_graph([node], + 'or_test', + inputs=[helper.make_tensor_value_info("in1", TensorProto.BOOL, list(x.shape)), + helper.make_tensor_value_info("in2", TensorProto.BOOL, list(y.shape))], + outputs=[helper.make_tensor_value_info("out", TensorProto.BOOL, list(outdata.shape))]) + + model = helper.make_model(graph, producer_name='or_test') + + for target, ctx in ctx_list(): + tvm_out = get_tvm_output(model, [x, y], target, ctx, outdata.shape) + tvm.testing.assert_allclose(outdata, tvm_out) + + +def test_or(): + # 2d + x = (np.random.randn(3, 4) > 0) + y = (np.random.randn(3, 4) > 0) + verify_or(indata=[x, y], dtype=bool) + + # 3d + x = (np.random.randn(3, 4, 5) > 0) + y = (np.random.randn(3, 4, 5) > 0) + verify_or(indata=[x, y], dtype=bool) + + # 4d + x = (np.random.randn(3, 4, 5, 6) > 0) + y = (np.random.randn(3, 4, 5, 6) > 0) + verify_or(indata=[x, y], dtype=bool) + + # 3d vs 1d + x = (np.random.randn(3, 4, 5) > 0) + y = (np.random.randn(5) > 0) + verify_or(indata=[x, y], dtype=bool) + + # 3d vs 2d + x = (np.random.randn(3, 4, 5) > 0) + y = (np.random.randn(4, 5) > 0) + verify_or(indata=[x, y], dtype=bool) + + if __name__ == '__main__': test_flatten() test_reshape() @@ -1347,3 +1394,4 @@ def test_erf(): test_and() test_tile() test_erf() + test_or() From 5b8e5077ade98c584d09a29736633c504eeac9a6 Mon Sep 17 00:00:00 2001 From: Jon Date: Fri, 25 Oct 2019 12:57:34 -0700 Subject: [PATCH 3/5] Add doc --- docs/api/python/topi.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index 3483668a5b08..0e203c176711 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -91,6 +91,7 @@ List of operators topi.greater_equal topi.less_equal topi.all + topi.any topi.logical_and topi.logical_or topi.logical_not @@ -151,6 +152,7 @@ topi .. autofunction:: topi.full .. autofunction:: topi.full_like .. autofunction:: topi.all +.. autofunction:: topi.any .. autofunction:: topi.max .. autofunction:: topi.sum .. autofunction:: topi.min From 95ac44854ff1599cdc8d7bb8e4fbf46d0ec57f15 Mon Sep 17 00:00:00 2001 From: Unknown Date: Sat, 26 Oct 2019 13:14:59 -0700 Subject: [PATCH 4/5] Add to relay docs --- docs/langref/relay_op.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 57325b53d974..db741206caa6 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -137,6 +137,7 @@ This level enables additional math and transform operators. tvm.relay.less tvm.relay.less_equal tvm.relay.all + tvm.relay.any tvm.relay.logical_and tvm.relay.logical_or tvm.relay.logical_not @@ -300,6 +301,7 @@ Level 4 Definitions .. autofunction:: tvm.relay.less .. autofunction:: tvm.relay.less_equal .. autofunction:: tvm.relay.all +.. autofunction:: tvm.relay.any .. autofunction:: tvm.relay.logical_and .. autofunction:: tvm.relay.logical_or .. autofunction:: tvm.relay.logical_not From 26f2a07b6d0604432cb5ed8969c2ab95535732b6 Mon Sep 17 00:00:00 2001 From: Jon Date: Mon, 28 Oct 2019 13:12:17 -0700 Subject: [PATCH 5/5] Dummy change to retrigger CI --- python/tvm/relay/frontend/onnx.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 53f3a97cb9e6..0ff62297d8b1 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -936,7 +936,6 @@ class Or(Elemwise): def _impl_v7(cls, inputs, attr, params): return _op.logical_or(inputs[0], inputs[1]) - # compatible operators that do NOT require any conversion. _identity_list = []