From d8dc61eeb5712fa15441274d6f7b35b251c1c564 Mon Sep 17 00:00:00 2001
From: masa
Date: Sun, 25 Oct 2020 17:14:33 +0900
Subject: [PATCH 01/13] where type rel with broadcast

---
 src/relay/op/tensor/transform.cc | 31 +++++++++----------------------
 src/relay/op/type_relations.h    |  2 ++
 2 files changed, 11 insertions(+), 22 deletions(-)

diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 68582b4918fe..44c8fe8e668e 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -45,6 +45,7 @@
 #include "../../transforms/pattern_utils.h"
 #include "../make_op.h"
 #include "../op_common.h"
+#include "../type_relations.h"

 namespace tvm {
 namespace relay {
@@ -1685,30 +1686,16 @@ bool WhereRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     return false;
   }

-  const auto& cond_shape = condition->shape;
-  const auto& x_shape = x->shape;
-  const auto& y_shape = y->shape;
-  ICHECK(x_shape.size() == y_shape.size()) << "x and y must have the same size";
+  CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
+                               << y->dtype;

-  if (cond_shape.size() != x_shape.size()) {
-    ICHECK_EQ(cond_shape.size(), 1) << "Shape of condition " << condition->shape
-                                    << " must be either equal to x or has dimension of 1.";
-  }
-  for (size_t i = 0; i < x_shape.size(); i++) {
-    ICHECK(reporter->AssertEQ(x_shape[i], y_shape[i]))
-        << "x and y must have the same shape: " << x_shape << " vs " << y_shape;
+  auto b_ty = Downcast<TensorType>(
+      ConcreteBroadcast(GetRef<TensorType>(x), GetRef<TensorType>(y), x->dtype));
+  auto ret_ty = ConcreteBroadcast(GetRef<TensorType>(condition), b_ty, b_ty->dtype);
+
+  LOG(INFO) << "where broadcast type:" << ret_ty;
+  reporter->Assign(types[3], ret_ty);

-    if (i < cond_shape.size()) {
-      ICHECK(reporter->AssertEQ(cond_shape[i], x_shape[i]))
-          << "condition and x must have the same shape: " << cond_shape << " vs " << x_shape;
-    }
-  }
-  if (x_shape.size() == 0) {
-    // if x and y are scalar, the condition shape becomes the output shape
-    reporter->Assign(types[3], TensorType(cond_shape, x->dtype));
-  } else {
-    reporter->Assign(types[3], TensorType(x_shape, x->dtype));
-  }
   return true;
 }

diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h
index 5ab8b121ae9d..a915d659f9a7 100644
--- a/src/relay/op/type_relations.h
+++ b/src/relay/op/type_relations.h
@@ -57,6 +57,8 @@ bool IdentityRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 bool BroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                   const TypeReporter& reporter);

+Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype);
+
 /*!
  * \brief The broadcast type relation, implements the broadcasting
  * rule over the two input types producing the broadcasted type.
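With this first patch, WhereRel composes ConcreteBroadcast over x and y and then over
condition, so the type of where now follows the numpy broadcasting rule for all three
arguments. A minimal numpy sketch of the intended semantics (the shapes and values here
are illustrative, not taken from the patch):

import numpy as np

# condition (2, 1), x (2, 2), y scalar () all broadcast to a common (2, 2),
# which is the shape the new type relation assigns to types[3].
cond = np.array([[True], [False]])
x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype="float32")
y = np.float32(-1.0)

out = np.where(cond, x, y)
print(out.shape, out)  # (2, 2)  [[1, 2], [-1, -1]]
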
From ff201f174b58b144fc7869a4241ea7e517625d0c Mon Sep 17 00:00:00 2001
From: masa
Date: Sun, 25 Oct 2020 17:15:15 +0900
Subject: [PATCH 02/13] add tests for where with broadcast

---
 include/tvm/topi/transform.h         | 71 ++++++++---------
 src/relay/op/tensor/transform.cc     | 11 ++--
 src/relay/op/type_relations.cc       |  2 +-
 src/relay/op/type_relations.h        |  2 +-
 tests/python/relay/test_op_level4.py | 86 ++++++++++++++++++----------
 5 files changed, 88 insertions(+), 84 deletions(-)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index fa27faf18f15..47f0be2ee37a 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -39,6 +39,8 @@
 #include
 #include
+#include "detail/broadcast.h"

 namespace tvm {
 namespace topi {
@@ -887,53 +889,30 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int axis, std::string
  */
 inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
                     std::string name = "T_where", std::string tag = kBroadcast) {
-  ICHECK_EQ(x->shape.size(), y->shape.size())
-      << "x and y must have the same shape.Got different number of dimension: " << x->shape.size()
-      << " vs " << y->shape.size();
-  ICHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
-                                << y->dtype;
+  CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
+                               << y->dtype;
+  auto get_out_shape = [&]() {
+    auto bh1 = detail::BroadcastShape(x->shape, y->shape);
+    Array<PrimExpr> common_shape1(bh1.common_shape.begin(), bh1.common_shape.end());
+    auto bh2 = detail::BroadcastShape(condition->shape, common_shape1);
+    Array<PrimExpr> common_shape2(bh2.common_shape.begin(), bh2.common_shape.end());
+    return common_shape2;
+  };

-  if (x->shape.size() == 0) {
-    return compute(
-        condition->shape,
-        [&](const Array<Var>& indices) {
-          PrimExpr cond;
-          if (condition->shape.size() == 0) {
-            cond = condition();
-          } else {
-            Array<PrimExpr> condition_idx{indices[0]};
-            cond = condition(condition_idx);
-          }
-          return tvm::tir::Select(cond != 0, x(), y());
-        },
-        name, tag);
-  } else if (condition->shape.size() != 1) {
-    ICHECK_EQ(condition->shape.size(), x->shape.size())
-        << "condition array must be either have the same shape as x or to be a "
-           "1-D array.Got different number of dimension: "
-        << condition->shape.size() << " vs " << x->shape.size();
-    return compute(
-        x->shape,
-        [&](const Array<Var>& indices) {
-          return tvm::tir::Select(condition(indices) != 0, x(indices), y(indices));
-        },
-        name, tag);
-  } else {
-    int64_t cond_first_dim = topi::GetConstInt(condition->shape[0]);
-    int64_t x_first_dim = topi::GetConstInt(x->shape[0]);
-    if (cond_first_dim > 0 && x_first_dim > 0) {
-      ICHECK_EQ(cond_first_dim, x_first_dim)
-          << "If condition is 1-D, the first dimension must be the same as x: " << cond_first_dim
-          << " vs " << x_first_dim;
-    }
-    return compute(
-        x->shape,
-        [&](const Array<Var>& indices) {
-          Array<PrimExpr> condition_idx{indices[0]};
-          return tvm::tir::Select(condition(condition_idx) != 0, x(indices), y(indices));
-        },
-        name, tag);
-  }
+  auto oshape = get_out_shape();
+
+  auto c_bh = detail::BroadcastShape(condition->shape, oshape);
+  auto x_bh = detail::BroadcastShape(x->shape, oshape);
+  auto y_bh = detail::BroadcastShape(y->shape, oshape);
+
+  auto select = [&](tvm::Array<tvm::tir::Var> ovars) {
+    auto c = condition(InputIndexFromBroadcast(ovars, condition, c_bh.vars1, c_bh.all_vars));
+    auto true_val = x(InputIndexFromBroadcast(ovars, x, x_bh.vars1, x_bh.all_vars));
+    auto false_val = y(InputIndexFromBroadcast(ovars, y, y_bh.vars1, y_bh.all_vars));
+    return tvm::tir::Select(c != 0, true_val, false_val);
+  };
+
+  return compute(oshape, select, name, tag);
 }

 /*!
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 44c8fe8e668e..4541f98b2ed5 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1689,13 +1689,14 @@ bool WhereRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
                                << y->dtype;

-  auto b_ty = Downcast<TensorType>(
-      ConcreteBroadcast(GetRef<TensorType>(x), GetRef<TensorType>(y), x->dtype));
-  auto ret_ty = ConcreteBroadcast(GetRef<TensorType>(condition), b_ty, b_ty->dtype);
+  auto tensor_ty_condition = GetRef<TensorType>(condition);
+  auto tensor_ty_x = GetRef<TensorType>(x);
+  auto tensor_ty_y = GetRef<TensorType>(y);

-  LOG(INFO) << "where broadcast type:" << ret_ty;
-  reporter->Assign(types[3], ret_ty);
+  auto b_ty = ConcreteBroadcast(tensor_ty_x, tensor_ty_y, x->dtype);
+  auto ret_ty = ConcreteBroadcast(tensor_ty_condition, b_ty, b_ty->dtype);

+  reporter->Assign(types[3], ret_ty);
   return true;
 }

diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc
index 3dc33c5022e0..7a3bfcb21ce6 100644
--- a/src/relay/op/type_relations.cc
+++ b/src/relay/op/type_relations.cc
@@ -64,7 +64,7 @@ bool EqualConstInt(const IndexExpr& lhs, int64_t value) {
   return false;
 }

-Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) {
+TensorType ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype) {
   std::vector<IndexExpr> oshape;
   size_t ndim1 = t1->shape.size();
   size_t ndim2 = t2->shape.size();
diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h
index a915d659f9a7..8e5803bef195 100644
--- a/src/relay/op/type_relations.h
+++ b/src/relay/op/type_relations.h
@@ -57,7 +57,7 @@ bool IdentityRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 bool BroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                   const TypeReporter& reporter);

-Type ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype);
+TensorType ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype);

 /*!
* \brief The broadcast type relation, implements the broadcasting diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index eafc743634d8..1657d4f13a34 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -152,35 +152,59 @@ def run(func, inputs, ref_res): op_res = intrp.evaluate(func)(*inputs) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) - shape = (3, 4) dtype = "float32" - cond = relay.var("cond", relay.TensorType(shape, dtype)) - x = relay.var("x", relay.TensorType(shape, dtype)) - y = relay.var("y", relay.TensorType(shape, dtype)) - z = relay.where(cond, x, y) - zz = run_infer_type(z) - assert zz.checked_type == relay.TensorType(shape, dtype) + # shape = (3, 4) + # cond = relay.var("cond", relay.TensorType(shape, dtype)) + # x = relay.var("x", relay.TensorType(shape, dtype)) + # y = relay.var("y", relay.TensorType(shape, dtype)) + # z = relay.where(cond, x, y) + # zz = run_infer_type(z) + # assert zz.checked_type == relay.TensorType(shape, dtype) - func = relay.Function([cond, x, y], z) - condition = np.random.uniform(low=-1, high=1, size=shape).astype(dtype) - x = np.random.uniform(size=shape).astype(dtype) - y = np.random.uniform(size=shape).astype(dtype) - ref_res = np.where(condition, x, y) + # func = relay.Function([cond, x, y], z) + # condition = np.random.uniform(low=-1, high=1, size=shape).astype(dtype) + # x = np.random.uniform(size=shape).astype(dtype) + # y = np.random.uniform(size=shape).astype(dtype) + # ref_res = np.where(condition, x, y) - run(func, [condition, x, y], ref_res) + # run(func, [condition, x, y], ref_res) - x = relay.const(1) - y = relay.const(-1) - shape = (3,) - dtype = "float32" - cond = relay.var("cond", relay.TensorType(shape, "bool")) - z = relay.where(cond, x, y) + # x = relay.const(1) + # y = relay.const(-1) + # shape = (3,) + # dtype = "float32" + # cond = relay.var("cond", relay.TensorType(shape, "bool")) + # z = relay.where(cond, x, y) + + # func = relay.Function([cond], z) + # condition = np.array([1, 0, 1], dtype=np.bool) + # ref_res = np.where(condition, 1, -1) + + # run(func, [condition], ref_res) + + def verify(x_np, y_np, cond_np): + ref_res = np.where(cond_np, x_np, y_np) + + cond = relay.var("cond", relay.TensorType(cond_np.shape, "bool")) + x = relay.var("x", relay.TensorType(x_np.shape, dtype)) + y = relay.var("y", relay.TensorType(y_np.shape, dtype)) + z = relay.where(cond, x, y) + func = relay.Function([cond, x, y], z) + + run(func, [cond_np, x_np, y_np], ref_res) + + x_np = np.array([[1, 2], [3, 4]], dtype) + y_np = np.array([[5, 6], [7, 8]], dtype) + cond_np = np.array([[1], [0]], "bool") + + # verify(x_np, y_np, cond_np) + # verify(x_np, y_np, cond_np.T) - func = relay.Function([cond], z) - condition = np.array([1, 0, 1], dtype=np.bool) - ref_res = np.where(condition, 1, -1) + x_np = np.random.randn(1, 12, 8, 8).astype(dtype) + y_np = np.array(-1., dtype) + cond_np = np.random.randn(1, 1, 8, 8) > 0 - run(func, [condition], ref_res) + verify(x_np, y_np, cond_np) def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"): @@ -498,12 +522,12 @@ def verify(dshape, begin, end, strides, vshape, test_ref=True): if __name__ == "__main__": - test_strided_slice() - test_strided_set() - test_binary_op() - test_cmp_type() - test_binary_int_broadcast_1() - test_binary_int_broadcast_2() + # test_strided_slice() + # test_strided_set() + # test_binary_op() + # test_cmp_type() + # test_binary_int_broadcast_1() + # 
test_binary_int_broadcast_2() test_where() - test_reduce_functions() - test_mean_var_std() + # test_reduce_functions() + # test_mean_var_std() From d8ec53892f32b03777b7d6414a96b024f846c8da Mon Sep 17 00:00:00 2001 From: masa Date: Sun, 25 Oct 2020 17:15:38 +0900 Subject: [PATCH 03/13] clean up tests --- tests/python/relay/test_op_level4.py | 78 ++++++++++++++-------------- 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index 1657d4f13a34..0cda849fba6e 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -152,56 +152,56 @@ def run(func, inputs, ref_res): op_res = intrp.evaluate(func)(*inputs) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) - dtype = "float32" - # shape = (3, 4) - # cond = relay.var("cond", relay.TensorType(shape, dtype)) - # x = relay.var("x", relay.TensorType(shape, dtype)) - # y = relay.var("y", relay.TensorType(shape, dtype)) - # z = relay.where(cond, x, y) - # zz = run_infer_type(z) - # assert zz.checked_type == relay.TensorType(shape, dtype) - - # func = relay.Function([cond, x, y], z) - # condition = np.random.uniform(low=-1, high=1, size=shape).astype(dtype) - # x = np.random.uniform(size=shape).astype(dtype) - # y = np.random.uniform(size=shape).astype(dtype) - # ref_res = np.where(condition, x, y) - - # run(func, [condition, x, y], ref_res) - - # x = relay.const(1) - # y = relay.const(-1) - # shape = (3,) - # dtype = "float32" - # cond = relay.var("cond", relay.TensorType(shape, "bool")) - # z = relay.where(cond, x, y) - - # func = relay.Function([cond], z) - # condition = np.array([1, 0, 1], dtype=np.bool) - # ref_res = np.where(condition, 1, -1) - - # run(func, [condition], ref_res) - def verify(x_np, y_np, cond_np): ref_res = np.where(cond_np, x_np, y_np) + args = [] + args_np = [] + vs = [] + cond = relay.var("cond", relay.TensorType(cond_np.shape, "bool")) - x = relay.var("x", relay.TensorType(x_np.shape, dtype)) - y = relay.var("y", relay.TensorType(y_np.shape, dtype)) - z = relay.where(cond, x, y) - func = relay.Function([cond, x, y], z) - run(func, [cond_np, x_np, y_np], ref_res) + args.append(cond) + args_np.append(cond_np) + + for v_name, v_np in [("x", x_np), ("y", y_np)]: + if len(v_np.shape) == 0: + v = relay.const(v_np.item()) + else: + v = relay.var(v_name, relay.TensorType(v_np.shape, dtype)) + args.append(v) + args_np.append(v_np) + vs.append(v) + + z = relay.where(cond, vs[0], vs[1]) + + func = relay.Function(args, z) + + run(func, args_np, ref_res) + + dtype = "float32" + + x_np = np.random.uniform(size=(3, 4)).astype(dtype) + y_np = np.random.uniform(size=(3, 4)).astype(dtype) + cond_np = np.random.uniform(low=-1, high=1, size=(3, 4)) > 0 + + verify(x_np, y_np, cond_np) + + x_np = np.array(1.0, dtype) + y_np = np.array(-1.0, dtype) + cond_np = np.array([1, 0, 1], dtype=np.bool) + + verify(x_np, y_np, cond_np) x_np = np.array([[1, 2], [3, 4]], dtype) y_np = np.array([[5, 6], [7, 8]], dtype) - cond_np = np.array([[1], [0]], "bool") + cond_np = np.array([[1], [0]], dtype=np.bool) - # verify(x_np, y_np, cond_np) - # verify(x_np, y_np, cond_np.T) + verify(x_np, y_np, cond_np) + verify(x_np, y_np, cond_np.T) x_np = np.random.randn(1, 12, 8, 8).astype(dtype) - y_np = np.array(-1., dtype) + y_np = np.array(-1.0, dtype) cond_np = np.random.randn(1, 1, 8, 8) > 0 verify(x_np, y_np, cond_np) From f9b1d229bbba32c5c42de23582bc5e8fed046a74 Mon Sep 17 00:00:00 2001 From: masa Date: Sun, 25 Oct 2020 17:19:08 
+0900 Subject: [PATCH 04/13] uncomment other tests --- tests/python/relay/test_op_level4.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index 0cda849fba6e..13b557b7eede 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -522,12 +522,12 @@ def verify(dshape, begin, end, strides, vshape, test_ref=True): if __name__ == "__main__": - # test_strided_slice() - # test_strided_set() - # test_binary_op() - # test_cmp_type() - # test_binary_int_broadcast_1() - # test_binary_int_broadcast_2() + test_strided_slice() + test_strided_set() + test_binary_op() + test_cmp_type() + test_binary_int_broadcast_1() + test_binary_int_broadcast_2() test_where() - # test_reduce_functions() - # test_mean_var_std() + test_reduce_functions() + test_mean_var_std() From 5d9811ace3ef1cf44bc595514e65da260cb70fd0 Mon Sep 17 00:00:00 2001 From: masa Date: Sun, 25 Oct 2020 17:25:34 +0900 Subject: [PATCH 05/13] add more tests --- tests/python/relay/test_op_level4.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index 13b557b7eede..ef363430a2eb 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -193,6 +193,12 @@ def verify(x_np, y_np, cond_np): verify(x_np, y_np, cond_np) + x_np = np.arange(10).astype(dtype) + y_np = 10 * x_np + cond_np = x_np < 5 + + verify(x_np, y_np, cond_np) + x_np = np.array([[1, 2], [3, 4]], dtype) y_np = np.array([[5, 6], [7, 8]], dtype) cond_np = np.array([[1], [0]], dtype=np.bool) @@ -206,6 +212,11 @@ def verify(x_np, y_np, cond_np): verify(x_np, y_np, cond_np) + x_np, y_np = np.ogrid[:3, :4] + cond_np = np.where(x_np < y_np, x_np, 10 + y_np).astype(np.bool) + + verify(x_np.astype(dtype), y_np.astype(dtype), cond_np) + def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"): test_func = funcs[0] From d377d8df4f7dd8cc7f473096bedc9426baa255ae Mon Sep 17 00:00:00 2001 From: masa Date: Mon, 26 Oct 2020 08:01:58 +0900 Subject: [PATCH 06/13] update doc --- python/tvm/relay/op/transform.py | 17 +++++++++-------- src/relay/op/tensor/transform.cc | 15 ++++----------- src/relay/op/type_relations.h | 7 +++++++ 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 14ac454aec64..855fd9369c34 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -649,25 +649,26 @@ def where(condition, x, y): condition. .. note:: - The shape of condition, x, and y needs to be the same. + Shapes of condition, x, and y must be broadcastable to a common shape. + Semantics follow numpy where function + https://numpy.org/doc/stable/reference/generated/numpy.where.html Parameters ---------- condition : relay.Expr - The condition array. The n-th element in `y` is selected when the n-th - value in the `condition` array is zero. Otherwise, the corresponding - element from `x` will be picked. + Where True, yield x, otherwise yield y x : relay.Expr - The first array to be selected. + The first array or scalar to be selected. y : relay.Expr - The second array to be selected. + The second array or scalar to be selected. Returns ------- result : relay.Expr - The selected array. + The selected array. The output shape is the broadcasted shape from + condition, x, and y. 
    Examples
    --------

    .. code-block:: python

        x = [[1, 2], [3, 4]]
        y = [[5, 6], [7, 8]]
        condition = [[0, 1], [-1, 0]]
        relay.where(conditon, x, y) = [[5, 2], [3, 8]]

-       condition = [1, 0]
+       condition = [[1], [0]]
        relay.where(conditon, x, y) = [[1, 2], [7, 8]]
    """
    return _make.where(condition, x, y)
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 4541f98b2ed5..0f5c63cf688b 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1719,17 +1719,10 @@ Return the elements, either from x or y, depending on the condition.
 Given three ndarrays, condition, x, and y, return an ndarray with the elements
 from x or y, depending on the elements from condition are true or false.
-x and y must have the same shape. If condition has the same shape as x,
-each element in the output array is from x if the corresponding element
-in the condition is true, and from y if false.
-If condition does not have the same shape as x, it must be a 1D array whose
-size is the same as x’s first dimension size. Each row of the output array
-is from x’s row if the corresponding element from condition is true, and
-from y’s row if false.
-
-When x and y are scalars, condition must be an 1D array. The output shape
-is the same as condition's shape.
+Shapes of condition, x, and y must be broadcastable to a common shape, which
+is the output shape of this op. Semantics follow numpy where function.
+https://numpy.org/doc/stable/reference/generated/numpy.where.html

 Note that all non-zero values are interpreted as True in condition.

 Examples::

   x = [[1, 2], [3, 4]]
   y = [[5, 6], [7, 8]]
   cond = [[0, 1], [-1, 0]]
   where(cond, x, y) = [[5, 2], [3, 8]]

-  cond = [1, 0]
+  cond = [[1], [0]]
   where(cond, x, y) = [[1, 2], [7, 8]]

   cond = [0, 1]
diff --git a/src/relay/op/type_relations.h b/src/relay/op/type_relations.h
index 8e5803bef195..6d6d5f70c0c2 100644
--- a/src/relay/op/type_relations.h
+++ b/src/relay/op/type_relations.h
@@ -57,6 +57,13 @@ bool IdentityRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 bool BroadcastRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                   const TypeReporter& reporter);

+/*!
+ * \brief Determine the broadcasted shape from two input shapes
+ * \param t1 One of two TensorType whose shapes are broadcasted
+ * \param t2 One of two TensorType whose shapes are broadcasted
+ * \param output_dtype dtype of the output TensorType
+ * \return A TensorType whose shape is broadcasted from two input TensorType.
+ */
 TensorType ConcreteBroadcast(const TensorType& t1, const TensorType& t2, DataType output_dtype);

 /*!
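The documented behavior above can be exercised directly through type inference. A small
sketch (the shapes are illustrative; run_infer_type is the helper the series' tests import
from tvm.relay.testing):

from tvm import relay
from tvm.relay.testing import run_infer_type

# cond (2, 1), x (2, 2), y (1, 2): the checked type is the broadcast shape (2, 2).
cond = relay.var("cond", relay.TensorType((2, 1), "bool"))
x = relay.var("x", relay.TensorType((2, 2), "float32"))
y = relay.var("y", relay.TensorType((1, 2), "float32"))

zz = run_infer_type(relay.where(cond, x, y))
assert zz.checked_type == relay.TensorType((2, 2), "float32")
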
From 1e9ca8b14425056751acf5a1893dddf44af63846 Mon Sep 17 00:00:00 2001
From: masa
Date: Mon, 26 Oct 2020 10:37:58 +0900
Subject: [PATCH 07/13] CHECK -> ICHECK

---
 include/tvm/topi/transform.h     | 4 ++--
 src/relay/op/tensor/transform.cc | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 47f0be2ee37a..b670755d97b7 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -889,8 +889,8 @@ inline Tensor take(const Tensor& a, const Tensor& indices, int axis, std::string
  */
 inline Tensor where(const Tensor& condition, const Tensor& x, const Tensor& y,
                     std::string name = "T_where", std::string tag = kBroadcast) {
-  CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
-                               << y->dtype;
+  ICHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
+                                << y->dtype;
   auto get_out_shape = [&]() {
     auto bh1 = detail::BroadcastShape(x->shape, y->shape);
     Array<PrimExpr> common_shape1(bh1.common_shape.begin(), bh1.common_shape.end());
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 0f5c63cf688b..c0af0876fccb 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1686,8 +1686,8 @@ bool WhereRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     return false;
   }

-  CHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
-                               << y->dtype;
+  ICHECK_EQ(x->dtype, y->dtype) << "x and y must have the same dtype: " << x->dtype << " vs "
+                                << y->dtype;

From 5423c70152c088015a6e51c3c9805b12df2e1153 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Mon, 26 Oct 2020 15:22:11 +0900
Subject: [PATCH 08/13] add where any test

---
 include/tvm/topi/broadcast.h      | 40 +++++++++++++++++++++++++++++++
 python/tvm/relay/op/_transform.py |  6 +++--
 python/tvm/topi/broadcast.py      | 22 ++++++++++++++++-
 src/topi/broadcast.cc             |  4 ++++
 tests/python/relay/test_any.py    | 31 ++++++++++++++++++++++++
 5 files changed, 100 insertions(+), 3 deletions(-)

diff --git a/include/tvm/topi/broadcast.h b/include/tvm/topi/broadcast.h
index f4f4f2ccb917..af6d01c0f3f0 100644
--- a/include/tvm/topi/broadcast.h
+++ b/include/tvm/topi/broadcast.h
@@ -69,6 +69,46 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t,
   return tvm::te::compute(oshape, l, name, tag);
 }

+inline tvm::te::Tensor broadcast_shape_tensors(const tvm::te::Tensor& shape_tensor1,
+                                               const tvm::te::Tensor& shape_tensor2,
+                                               std::string name = "T_broadcast_shape_tensors",
+                                               std::string tag = kBroadcast) {
+  const auto rank1 = detail::GetConstInt(shape_tensor1->shape[0]);
+  const auto rank2 = detail::GetConstInt(shape_tensor2->shape[0]);
+  const auto out_rank = std::max(rank1, rank2);
+  const tvm::PrimExpr one = tvm::cast(shape_tensor1->dtype, PrimExpr(1));
+
+  auto select_dim = [&](const tvm::te::Tensor& shape_tensor, int rank,
+                        tvm::tir::Var index) -> PrimExpr {
+    if (rank < out_rank) {
+      // if the rank is smaller, dimension 1 is prepended according to
+      // the numpy broadcasting semantics.
+      return tvm::tir::Select(rank - (out_rank - index) < 0, one,
+                              shape_tensor[rank - (out_rank - index)]);
+    } else {
+      // rank == out_rank, safe to index directly
+      return shape_tensor[index];
+    }
+  };
+
+  auto func = [&](tvm::Array<tvm::tir::Var> ovars) {
+    auto index = ovars[0];
+    PrimExpr dim1 = select_dim(shape_tensor1, rank1, index);
+    PrimExpr dim2 = select_dim(shape_tensor2, rank2, index);
+    if (topi::detail::EqualCheck(one, dim1)) {
+      return dim2;
+    } else if (topi::detail::EqualCheck(one, dim2)) {
+      return dim1;
+    }
+    return tvm::max(dim1, dim2);
+  };
+
+  Array<PrimExpr> oshape;
+  oshape.push_back(PrimExpr(out_rank));
+
+  return tvm::te::compute(oshape, func, name, tag);
+}
+
 #define TOPI_DEFINE_BCAST_OP(Name, ComputeRule)                                              \
   inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; } \
   inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B,            \
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 415529fdcb9a..6e57e8a539e9 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -817,6 +817,8 @@ def where_shape_func(attrs, inputs, _):
     """
     cond_shape = inputs[0]
     x_shape = inputs[1]
-    out_shape = x_shape if x_shape.shape else cond_shape
+    y_shape = inputs[2]
+    bcast_shape = topi.broadcast.broadcast_shape_tensors(x_shape, y_shape)
+    out_shape = topi.broadcast.broadcast_shape_tensors(bcast_shape, cond_shape)

-    return [topi.math.identity(out_shape)]
+    return [out_shape]
diff --git a/python/tvm/topi/broadcast.py b/python/tvm/topi/broadcast.py
index 2b350ff817d9..7161887a4f56 100644
--- a/python/tvm/topi/broadcast.py
+++ b/python/tvm/topi/broadcast.py
@@ -22,7 +22,7 @@
 def broadcast_to(data, shape):
     """Broadcast the src to the target shape

-    We follows the numpy broadcasting rule.
+    We follow the numpy broadcasting rule.
See also https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html Parameters @@ -40,6 +40,26 @@ def broadcast_to(data, shape): return _cpp.broadcast_to(data, shape) +def broadcast_shape_tensors(shape_tensor1, shape_tensor2): + """ Compute a shape tensor whose values represents the broadcasted shape + of two input shape tensors + + Parameters + ---------- + shape_tensor1 : tvm.te.Tensor + One of input shape tensors + + shape_tensor2 : tvm.te.Tensor + One of input shape tensors + + Returns + ------- + ret : tvm.te.Tensor + A shape tensor whose values represents the broadcasted shape + """ + return _cpp.broadcast_shape_tensors(shape_tensor1, shape_tensor2) + + def add(lhs, rhs): """Addition with auto-broadcasting diff --git a/src/topi/broadcast.cc b/src/topi/broadcast.cc index f6a28c7722af..46b2d69dbe7f 100644 --- a/src/topi/broadcast.cc +++ b/src/topi/broadcast.cc @@ -76,5 +76,9 @@ TVM_REGISTER_GLOBAL("topi.broadcast_to").set_body([](TVMArgs args, TVMRetValue* *rv = broadcast_to(args[0], args[1]); }); +TVM_REGISTER_GLOBAL("topi.broadcast_shape_tensors").set_body([](TVMArgs args, TVMRetValue* rv) { + *rv = broadcast_shape_tensors(args[0], args[1]); +}); + } // namespace topi } // namespace tvm diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 872728514c3e..3f43fd99ba65 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -1236,5 +1236,36 @@ def test_any_stack(): verify_any_stack(any_dims(4), (2, 1, 1, 4), 2, 2) +def verify_any_where(cond_shape, x_shape, y_shape, cond_np_shape, x_np_shape, y_np_shape): + dtype = "float32" + cond = relay.var("cond", shape=cond_shape, dtype="bool") + x = relay.var("x", shape=x_shape, dtype=dtype) + y = relay.var("y", shape=y_shape, dtype=dtype) + z = relay.where(cond, x, y) + mod = tvm.IRModule() + mod["main"] = relay.Function([cond, x, y], z) + + cond_np = np.random.randn(*cond_np_shape) > 0 + x_np = np.random.randn(*x_np_shape).astype(dtype) + y_np = np.random.randn(*y_np_shape).astype(dtype) + expected = np.where(cond_np, x_np, y_np) + + check_result([cond_np, x_np, y_np], mod, expected) + + +@tvm.testing.uses_gpu +def test_any_where(): + verify_any_where(any_dims(1), (5,), (5,), (5,), (5,), (5,)) + verify_any_where(any_dims(1), any_dims(1), (5,), (5,), (5,), (5,)) + verify_any_where(any_dims(1), any_dims(1), any_dims(1), (5,), (5,), (5,)) + verify_any_where((5,), any_dims(1), any_dims(1), (5,), (5,), (5,)) + + # where with broadcast + verify_any_where(any_dims(1), any_dims(1), any_dims(1), (5,), (1,), (5,)) + verify_any_where(any_dims(1), any_dims(2), any_dims(2), (5,), (5, 5), (5, 5)) + verify_any_where(any_dims(1), any_dims(1), any_dims(2), (5,), (5,), (5, 5)) + verify_any_where(any_dims(2), any_dims(2), any_dims(2), (3, 4), (3, 1), (1, 4)) + + if __name__ == "__main__": pytest.main([__file__]) From 50eca1544db257a6ac9fc37daac413314f5db2c1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 26 Oct 2020 15:57:00 +0900 Subject: [PATCH 09/13] fix format --- python/tvm/topi/broadcast.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/broadcast.py b/python/tvm/topi/broadcast.py index 7161887a4f56..36d28feb8dcc 100644 --- a/python/tvm/topi/broadcast.py +++ b/python/tvm/topi/broadcast.py @@ -41,8 +41,8 @@ def broadcast_to(data, shape): def broadcast_shape_tensors(shape_tensor1, shape_tensor2): - """ Compute a shape tensor whose values represents the broadcasted shape - of two input shape tensors + """Compute a shape tensor whose values represents the 
broadcasted shape + of two input shape tensors Parameters ---------- From 2fffdec062563ce3a56cdbef8a41e889e26aad09 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 26 Oct 2020 17:07:51 +0900 Subject: [PATCH 10/13] remove useless detections for one --- include/tvm/topi/broadcast.h | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/include/tvm/topi/broadcast.h b/include/tvm/topi/broadcast.h index af6d01c0f3f0..50ea369acc60 100644 --- a/include/tvm/topi/broadcast.h +++ b/include/tvm/topi/broadcast.h @@ -69,6 +69,7 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t, return tvm::te::compute(oshape, l, name, tag); } +// This is used in the shape func of where op inline tvm::te::Tensor broadcast_shape_tensors(const tvm::te::Tensor& shape_tensor1, const tvm::te::Tensor& shape_tensor2, std::string name = "T_broadcast_shape_tensors", @@ -95,11 +96,6 @@ inline tvm::te::Tensor broadcast_shape_tensors(const tvm::te::Tensor& shape_tens auto index = ovars[0]; PrimExpr dim1 = select_dim(shape_tensor1, rank1, index); PrimExpr dim2 = select_dim(shape_tensor2, rank2, index); - if (topi::detail::EqualCheck(one, dim1)) { - return dim2; - } else if (topi::detail::EqualCheck(one, dim2)) { - return dim1; - } return tvm::max(dim1, dim2); }; From a073652ac1b745698c6eba3fbc51983fc84c0c9a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 26 Oct 2020 21:18:12 +0900 Subject: [PATCH 11/13] set manual seed --- tests/python/frontend/pytorch/test_lstm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/frontend/pytorch/test_lstm.py b/tests/python/frontend/pytorch/test_lstm.py index 39d78c70c0fb..1197990f54ba 100644 --- a/tests/python/frontend/pytorch/test_lstm.py +++ b/tests/python/frontend/pytorch/test_lstm.py @@ -277,6 +277,8 @@ def test_custom_lstm(): num_layers = 3 state_tensor_shape = (batch, hidden_size) + torch.manual_seed(1) + inp = torch.randn(seq_len, batch, input_size) input_shapes = [ From d8c5076aec16e32ea7f00d89af8eb82c0e0984a2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 27 Oct 2020 14:44:28 +0900 Subject: [PATCH 12/13] ported shape broadcast helper func to hybridscript --- python/tvm/relay/op/_transform.py | 32 +++++++++++++++++++++++++++++-- tests/python/relay/test_any.py | 17 ++++++++++++++-- 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 6e57e8a539e9..fa5bfdf43a34 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -810,6 +810,33 @@ def stack_shape_func(attrs, inputs, _): return [_stack_shape_func(inputs[0], convert(axis), convert(len(inputs)))] +@script +def _broadcast_shape_tensors(shape_tensor1, shape_tensor2): + rank1 = shape_tensor1.shape[0] + rank2 = shape_tensor2.shape[0] + out_rank = max(rank1, rank2) + bcast_shape_tensor = output_tensor((out_rank,), "int64") + + for index in const_range(out_rank): + dim1 = int64(1) + dim2 = int64(1) + + if rank1 == out_rank: + dim1 = shape_tensor1[index] + elif rank1 - (out_rank - index) >= 0: + dim1 = shape_tensor1[rank1 - (out_rank - index)] + + if rank2 == out_rank: + dim2 = shape_tensor2[index] + elif rank2 - (out_rank - index) >= 0: + dim2 = shape_tensor2[rank2 - (out_rank - index)] + + assert dim1 == dim2 or dim1 == 1 or dim2 == 1, "Invalid broadcast shapes" + bcast_shape_tensor[index] = max(dim1, dim2) + + return bcast_shape_tensor + + @_reg.register_shape_func("where", False) def where_shape_func(attrs, inputs, _): """ @@ -818,7 +845,8 @@ def 
where_shape_func(attrs, inputs, _): cond_shape = inputs[0] x_shape = inputs[1] y_shape = inputs[2] - bcast_shape = topi.broadcast.broadcast_shape_tensors(x_shape, y_shape) - out_shape = topi.broadcast.broadcast_shape_tensors(bcast_shape, cond_shape) + + bcast_shape = _broadcast_shape_tensors(x_shape, y_shape) + out_shape = _broadcast_shape_tensors(bcast_shape, cond_shape) return [out_shape] diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 3f43fd99ba65..b1b068ebb32a 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -1236,7 +1236,9 @@ def test_any_stack(): verify_any_stack(any_dims(4), (2, 1, 1, 4), 2, 2) -def verify_any_where(cond_shape, x_shape, y_shape, cond_np_shape, x_np_shape, y_np_shape): +def verify_any_where( + cond_shape, x_shape, y_shape, cond_np_shape, x_np_shape, y_np_shape, y_np_shape_invalid=None +): dtype = "float32" cond = relay.var("cond", shape=cond_shape, dtype="bool") x = relay.var("x", shape=x_shape, dtype=dtype) @@ -1252,6 +1254,15 @@ def verify_any_where(cond_shape, x_shape, y_shape, cond_np_shape, x_np_shape, y_ check_result([cond_np, x_np, y_np], mod, expected) + # verify invalid broadcasting check + if y_np_shape_invalid: + y_np_bad = np.random.randn(*y_np_shape_invalid).astype(dtype) + try: + check_result([cond_np, x_np, y_np_bad], mod, expected) + except tvm.error.TVMError as e: + error_msg = str(e).split("\n")[-1] + assert "Invalid broadcast shapes" in error_msg + @tvm.testing.uses_gpu def test_any_where(): @@ -1264,7 +1275,9 @@ def test_any_where(): verify_any_where(any_dims(1), any_dims(1), any_dims(1), (5,), (1,), (5,)) verify_any_where(any_dims(1), any_dims(2), any_dims(2), (5,), (5, 5), (5, 5)) verify_any_where(any_dims(1), any_dims(1), any_dims(2), (5,), (5,), (5, 5)) - verify_any_where(any_dims(2), any_dims(2), any_dims(2), (3, 4), (3, 1), (1, 4)) + verify_any_where( + any_dims(2), any_dims(2), any_dims(2), (3, 4), (3, 1), (1, 4), y_np_shape_invalid=(2, 4) + ) if __name__ == "__main__": From fd4d9505345a643700486616c824108a9498336b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 27 Oct 2020 15:08:49 +0900 Subject: [PATCH 13/13] remove shape function helper from cpp --- include/tvm/topi/broadcast.h | 36 ------------------------------------ python/tvm/topi/broadcast.py | 22 +--------------------- src/topi/broadcast.cc | 4 ---- 3 files changed, 1 insertion(+), 61 deletions(-) diff --git a/include/tvm/topi/broadcast.h b/include/tvm/topi/broadcast.h index 50ea369acc60..f4f4f2ccb917 100644 --- a/include/tvm/topi/broadcast.h +++ b/include/tvm/topi/broadcast.h @@ -69,42 +69,6 @@ inline tvm::te::Tensor broadcast_to(const tvm::te::Tensor& t, return tvm::te::compute(oshape, l, name, tag); } -// This is used in the shape func of where op -inline tvm::te::Tensor broadcast_shape_tensors(const tvm::te::Tensor& shape_tensor1, - const tvm::te::Tensor& shape_tensor2, - std::string name = "T_broadcast_shape_tensors", - std::string tag = kBroadcast) { - const auto rank1 = detail::GetConstInt(shape_tensor1->shape[0]); - const auto rank2 = detail::GetConstInt(shape_tensor2->shape[0]); - const auto out_rank = std::max(rank1, rank2); - const tvm::PrimExpr one = tvm::cast(shape_tensor1->dtype, PrimExpr(1)); - - auto select_dim = [&](const tvm::te::Tensor& shape_tensor, int rank, - tvm::tir::Var index) -> PrimExpr { - if (rank < out_rank) { - // if the rank is smaller, dimension 1 is prepended according to - // the numpy broadcasting semantics. 
-      return tvm::tir::Select(rank - (out_rank - index) < 0, one,
-                              shape_tensor[rank - (out_rank - index)]);
-    } else {
-      // rank == out_rank, safe to index directly
-      return shape_tensor[index];
-    }
-  };
-
-  auto func = [&](tvm::Array<tvm::tir::Var> ovars) {
-    auto index = ovars[0];
-    PrimExpr dim1 = select_dim(shape_tensor1, rank1, index);
-    PrimExpr dim2 = select_dim(shape_tensor2, rank2, index);
-    return tvm::max(dim1, dim2);
-  };
-
-  Array<PrimExpr> oshape;
-  oshape.push_back(PrimExpr(out_rank));
-
-  return tvm::te::compute(oshape, func, name, tag);
-}
-
 #define TOPI_DEFINE_BCAST_OP(Name, ComputeRule)                                              \
   inline tvm::PrimExpr Name(const tvm::PrimExpr& a, const tvm::PrimExpr& b) { ComputeRule; } \
   inline tvm::te::Tensor Name(const tvm::te::Tensor& A, const tvm::te::Tensor& B,            \
diff --git a/python/tvm/topi/broadcast.py b/python/tvm/topi/broadcast.py
index 36d28feb8dcc..2b350ff817d9 100644
--- a/python/tvm/topi/broadcast.py
+++ b/python/tvm/topi/broadcast.py
@@ -22,7 +22,7 @@
 def broadcast_to(data, shape):
     """Broadcast the src to the target shape

-    We follow the numpy broadcasting rule.
+    We follows the numpy broadcasting rule.
     See also https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html

     Parameters
@@ -40,26 +40,6 @@ def broadcast_to(data, shape):
     return _cpp.broadcast_to(data, shape)


-def broadcast_shape_tensors(shape_tensor1, shape_tensor2):
-    """Compute a shape tensor whose values represents the broadcasted shape
-    of two input shape tensors
-
-    Parameters
-    ----------
-    shape_tensor1 : tvm.te.Tensor
-        One of input shape tensors
-
-    shape_tensor2 : tvm.te.Tensor
-        One of input shape tensors
-
-    Returns
-    -------
-    ret : tvm.te.Tensor
-        A shape tensor whose values represents the broadcasted shape
-    """
-    return _cpp.broadcast_shape_tensors(shape_tensor1, shape_tensor2)
-
-
 def add(lhs, rhs):
     """Addition with auto-broadcasting

diff --git a/src/topi/broadcast.cc b/src/topi/broadcast.cc
index 46b2d69dbe7f..f6a28c7722af 100644
--- a/src/topi/broadcast.cc
+++ b/src/topi/broadcast.cc
@@ -76,9 +76,5 @@ TVM_REGISTER_GLOBAL("topi.broadcast_to").set_body([](TVMArgs args, TVMRetValue*
   *rv = broadcast_to(args[0], args[1]);
 });

-TVM_REGISTER_GLOBAL("topi.broadcast_shape_tensors").set_body([](TVMArgs args, TVMRetValue* rv) {
-  *rv = broadcast_shape_tensors(args[0], args[1]);
-});
-
 }  // namespace topi
 }  // namespace tvm
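
For reference, the shape rule implemented by the hybrid-script _broadcast_shape_tensors in
PATCH 12 (and by the C++ helper removed above) is ordinary numpy-style shape broadcasting:
right-align the two shapes, pad the shorter one with 1s, and take the max of each dimension
pair, requiring that mismatched dimensions include a 1. A plain-Python sketch of the same
rule, for illustration only:

def broadcast_shapes(shape1, shape2):
    # Right-align the shapes; a shape contributes 1 for positions beyond its rank.
    out_rank = max(len(shape1), len(shape2))
    out = []
    for index in range(out_rank):
        i1 = len(shape1) - (out_rank - index)
        i2 = len(shape2) - (out_rank - index)
        dim1 = shape1[i1] if i1 >= 0 else 1
        dim2 = shape2[i2] if i2 >= 0 else 1
        assert dim1 == dim2 or dim1 == 1 or dim2 == 1, "Invalid broadcast shapes"
        out.append(max(dim1, dim2))
    return tuple(out)

# Mirrors cases covered by test_any_where:
assert broadcast_shapes((3, 1), (1, 4)) == (3, 4)
assert broadcast_shapes((5,), (5, 5)) == (5, 5)
assert broadcast_shapes((1, 1, 8, 8), (1, 12, 8, 8)) == (1, 12, 8, 8)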