From 393e7d23b3b45a57c3f923b5a2406cb7f406dd1e Mon Sep 17 00:00:00 2001 From: Siju Samuel Date: Thu, 22 Nov 2018 11:31:04 +0530 Subject: [PATCH] Relay reshape reshape_like compute and schedule --- python/tvm/relay/op/_transform.py | 8 +++++ src/relay/op/tensor/transform.cc | 18 +++++++++-- tests/python/relay/test_op_level3.py | 47 +++++++++++++++++++++++++++- 3 files changed, 70 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 01814e0f73e0..c1ae5a417455 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -10,3 +10,11 @@ # slice_like _reg.register_schedule("slice_like", schedule_injective) _reg.register_pattern("slice_like", OpPattern.INJECTIVE) + +# reshape +_reg.register_schedule("reshape", schedule_injective) +_reg.register_pattern("reshape", OpPattern.INJECTIVE) + +# reshape_like +_reg.register_schedule("reshape_like", schedule_injective) +_reg.register_pattern("reshape_like", OpPattern.INJECTIVE) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index a9e0a969fc5b..52363e8af92a 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -376,7 +376,15 @@ Example:: .set_attrs_type_key("relay.attrs.ReshapeAttrs") .add_argument("data", "Tensor", "The input tensor.") .set_support_level(3) -.add_type_rel("Reshape", ReshapeRel); +.add_type_rel("Reshape", ReshapeRel) +.set_attr("FTVMCompute", [](const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + const auto* param = attrs.as(); + CHECK(param != nullptr); + return Array{ topi::reshape(inputs[0], param->newshape) }; +}); /*! @@ -431,7 +439,13 @@ the input array into an output array with the same shape as the second input arr .add_argument("data", "Tensor", "The input tensor.") .add_argument("shape_like", "Tensor", "Shape tensor.") .set_support_level(3) -.add_type_rel("ReshapeLike", ReshapeLikeRel); +.add_type_rel("ReshapeLike", ReshapeLikeRel) +.set_attr("FTVMCompute", [](const Attrs& attrs, + const Array& inputs, + const Type& out_type, + const Target& target) { + return Array{ topi::reshape(inputs[0], inputs[1]->shape) }; +}); // Take diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 6f8fbd551293..47ea662eb076 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -107,8 +107,28 @@ def test_reshape_infer_type(): assert yy.checked_type == relay.TensorType( (n, t, 2000), "float32") +def test_reshape(): + def verify_reshape(shape, oshape): + x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") + ref_res = np.reshape(x_data, oshape) -def test_reshape_like(): + x = relay.var("x", relay.TensorType(shape, "float32")) + z = relay.reshape(x, newshape=ref_res.shape) + zz = relay.ir_pass.infer_type(z) + assert "newshape=" in z.astext() + assert zz.checked_type == relay.ty.TensorType(oshape, "float32") + + func = relay.Function([x], z) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x_data) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + verify_reshape((2, 3, 4), (8, 3)) + verify_reshape((4, 7), (2, 7, 2)) + +def test_reshape_like_infer_type(): # concrete shape x = relay.var("x", relay.TensorType((1, 2, 3), "float32")) y = relay.var("y", relay.TensorType((1,6), "float32")) @@ -125,6 +145,29 @@ def test_reshape_like(): assert zz.checked_type == relay.TensorType((1, 8, 8), "float32") +def test_reshape_like(): + def verify_reshape_like(shape, oshape): + x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") + y_data = np.random.uniform(low=-1, high=1, size=oshape).astype("float32") + ref_res = np.reshape(x_data, y_data.shape) + + x = relay.var("x", relay.TensorType(shape, "float32")) + y = relay.var("x", relay.TensorType(oshape, "float32")) + z = relay.reshape_like(x, y) + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.ty.TensorType(ref_res.shape, "float32") + + func = relay.Function([x, y], z) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x_data, y_data) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + + verify_reshape_like((2, 3, 4), (1, 8, 3)) + verify_reshape_like((4, 7), (2, 7, 2)) + def test_take_infer_type(): def verify_take(dshape, indices_shape, oshape, axis=None): x = relay.var("x", relay.TensorType(dshape, "float32")) @@ -302,6 +345,8 @@ def test_infer_type_prelu(): test_clip() test_transpose_infer_type() test_reshape_infer_type() + test_reshape() + test_reshape_like_infer_type() test_reshape_like() test_take_infer_type() test_full()