From e18dbd3149bba37562448093ed23aeadb533305d Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 18 Dec 2020 10:33:09 -0800 Subject: [PATCH 1/4] Separate reshape and reverse_reshape. --- include/tvm/relay/attrs/transform.h | 4 -- src/relay/op/dyn/tensor/transform.cc | 1 - src/relay/op/tensor/transform.cc | 75 +++++++++++++++++++++------- 3 files changed, 57 insertions(+), 23 deletions(-) diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index cbe989f93558..efa44e026c51 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -83,13 +83,9 @@ struct TransposeAttrs : public tvm::AttrsNode { /*! \brief Attributes used in reshape operators */ struct ReshapeAttrs : public tvm::AttrsNode { Array newshape; - bool reverse; TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") { TVM_ATTR_FIELD(newshape).describe( "The new shape. Should be compatible with the original shape."); - TVM_ATTR_FIELD(reverse) - .describe("Infer the special values from right to left if true") - .set_default(false); } }; // struct ReshapeAttrs diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc index 815f24b6bda9..e4e81e3612fb 100644 --- a/src/relay/op/dyn/tensor/transform.cc +++ b/src/relay/op/dyn/tensor/transform.cc @@ -90,7 +90,6 @@ Array ReshapeCompute(const Attrs& attrs, const Array& in Expr MakeReshape(Expr data, Expr newshape) { auto attrs = make_object(); - attrs->reverse = false; static const Op& op = Op::Get("dyn.reshape"); return Call(op, {data, newshape}, Attrs(attrs), {}); } diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 640943eac805..953640f73f88 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -455,13 +455,13 @@ RELAY_REGISTER_OP("transpose") TVM_REGISTER_NODE_TYPE(ReshapeAttrs); TVM_REGISTER_NODE_TYPE(ReshapeLikeAttrs); -Array infer_newshape(const Array& data_shape, const Attrs& attrs) { +Array infer_newshape(const Array& data_shape, const Attrs& attrs, bool reverse) { const auto* param = attrs.as(); Array oshape; Array ishape; Array newshape; - if (param->reverse) { + if (reverse) { ishape.Assign(data_shape.rbegin(), data_shape.rend()); newshape.Assign(param->newshape.rbegin(), param->newshape.rend()); } else { @@ -584,7 +584,6 @@ Array infer_newshape(const Array& data_shape, const Attrs& bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - const auto* param = attrs.as(); // types: [data, result] ICHECK_EQ(types.size(), 2); const auto* data = types[0].as(); @@ -594,16 +593,12 @@ bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, return false; } - const auto& oshape = infer_newshape(data->shape, attrs); + const auto& oshape = infer_newshape(data->shape, attrs, false); // Verify that the sum of dimensions in the output shape is the sum of // dimensions in the input shape Array data_shape; - if (param->reverse) { - data_shape.Assign(data->shape.rbegin(), data->shape.rend()); - } else { - data_shape = data->shape; - } + data_shape = data->shape; bool found_dynamic = false; int64_t oshape_sum = 1; @@ -633,12 +628,58 @@ bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, << "Input tensor shape and reshaped shape are not compatible"; } - if (param->reverse) { - reporter->Assign(types[1], - TensorType(Array(oshape.rbegin(), oshape.rend()), data->dtype)); - } else { - reporter->Assign(types[1], TensorType(oshape, data->dtype)); + reporter->Assign(types[1], TensorType(oshape, data->dtype)); + return true; +} + +bool ReverseReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + // types: [data, result] + ICHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) { + ICHECK(types[0].as()) + << "reshape: expect input type to be TensorType but get " << types[0]; + return false; + } + + const auto& oshape = infer_newshape(data->shape, attrs, true); + + // Verify that the sum of dimensions in the output shape is the sum of + // dimensions in the input shape + Array data_shape; + data_shape.Assign(data->shape.rbegin(), data->shape.rend()); + + bool found_dynamic = false; + int64_t oshape_sum = 1; + for (auto& x : oshape) { + // Check if we have a dynamic shape. If we do, we can't verify if the + // reshape is valid. Dynamic shapes are marker by using Any, but can also + // occur from SizeVar's. In the case of SizeVar, the shape expression can + // be an AST. We can't easily check if we have an AST because of a ShapeVar + // or some other reason, so our check for dynamic shape is just if we can + // convert the shape to in integer or not. + if (!x->IsInstance()) { + found_dynamic = true; + break; + } + oshape_sum *= Downcast(x)->value; } + int64_t data_shape_sum = 1; + for (auto& x : data_shape) { + if (!x->IsInstance()) { + found_dynamic = true; + break; + } + data_shape_sum *= Downcast(x)->value; + } + if (!found_dynamic) { + ICHECK_EQ(oshape_sum, data_shape_sum) + << "Input tensor shape and reshaped shape are not compatible"; + } + + reporter->Assign(types[1], + TensorType(Array(oshape.rbegin(), oshape.rend()), data->dtype)); return true; } @@ -701,7 +742,7 @@ Array ReshapeCompute(const Attrs& attrs, const Array& in } if (newshape_has_any) { - newshape = infer_newshape(inputs[0]->shape, attrs); + newshape = infer_newshape(inputs[0]->shape, attrs, false); } return {topi::reshape(inputs[0], newshape)}; } @@ -709,7 +750,6 @@ Array ReshapeCompute(const Attrs& attrs, const Array& in Expr MakeReshape(Expr data, Array newshape) { auto attrs = make_object(); attrs->newshape = std::move(newshape); - attrs->reverse = false; static const Op& op = Op::Get("reshape"); return Call(op, {data}, Attrs(attrs), {}); } @@ -2867,7 +2907,6 @@ RELAY_REGISTER_OP("auto_scheduler_layout_transform") Expr MakeReverseReshape(Expr data, Array newshape) { auto attrs = make_object(); attrs->newshape = std::move(newshape); - attrs->reverse = true; static const Op& op = Op::Get("contrib_reverse_reshape"); return Call(op, {data}, Attrs(attrs), {}); } @@ -2892,7 +2931,7 @@ example below:: .set_attrs_type() .add_argument("data", "Tensor", "The input tensor.") .set_support_level(10) - .add_type_rel("Reshape", ReshapeRel) + .add_type_rel("ReverseReshape", ReverseReshapeRel) .set_attr("FTVMCompute", ReshapeCompute) .set_attr("TOpPattern", kInjective); From 11e69ec030b5386d6bd11ee9c9d7fe185aff0930 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 18 Dec 2020 10:34:45 -0800 Subject: [PATCH 2/4] Formatting. --- src/relay/op/tensor/transform.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 953640f73f88..d7984527b39d 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -455,7 +455,8 @@ RELAY_REGISTER_OP("transpose") TVM_REGISTER_NODE_TYPE(ReshapeAttrs); TVM_REGISTER_NODE_TYPE(ReshapeLikeAttrs); -Array infer_newshape(const Array& data_shape, const Attrs& attrs, bool reverse) { +Array infer_newshape(const Array& data_shape, const Attrs& attrs, + bool reverse) { const auto* param = attrs.as(); Array oshape; Array ishape; @@ -679,7 +680,7 @@ bool ReverseReshapeRel(const Array& types, int num_inputs, const Attrs& at } reporter->Assign(types[1], - TensorType(Array(oshape.rbegin(), oshape.rend()), data->dtype)); + TensorType(Array(oshape.rbegin(), oshape.rend()), data->dtype)); return true; } From b1385dc3bf19c778be11178011406fccdac6e6b3 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 18 Dec 2020 11:24:40 -0800 Subject: [PATCH 3/4] Fix test. --- tests/python/contrib/test_arm_compute_lib/test_reshape.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/contrib/test_arm_compute_lib/test_reshape.py b/tests/python/contrib/test_arm_compute_lib/test_reshape.py index 9364c6b1a478..94942727416a 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_reshape.py +++ b/tests/python/contrib/test_arm_compute_lib/test_reshape.py @@ -50,7 +50,6 @@ def _get_expected_codegen(input_shape, output_shape, dtype): "newshape": [[str(s) for s in output_shape]], "shape": [[list(output_shape)]], "dtype": [[dtype]], - "reverse": [["0"]], }, } From b74627714f569dcf3d7de1cde8eb5cc97e757468 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Wed, 30 Dec 2020 11:41:07 -0800 Subject: [PATCH 4/4] Changed name to match google c format. --- src/relay/op/tensor/transform.cc | 10 +++++----- src/relay/op/tensor/transform.h | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index d7984527b39d..f8d6e0d1fc68 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -455,8 +455,8 @@ RELAY_REGISTER_OP("transpose") TVM_REGISTER_NODE_TYPE(ReshapeAttrs); TVM_REGISTER_NODE_TYPE(ReshapeLikeAttrs); -Array infer_newshape(const Array& data_shape, const Attrs& attrs, - bool reverse) { +Array InferNewShape(const Array& data_shape, const Attrs& attrs, + bool reverse) { const auto* param = attrs.as(); Array oshape; Array ishape; @@ -594,7 +594,7 @@ bool ReshapeRel(const Array& types, int num_inputs, const Attrs& attrs, return false; } - const auto& oshape = infer_newshape(data->shape, attrs, false); + const auto& oshape = InferNewShape(data->shape, attrs, false); // Verify that the sum of dimensions in the output shape is the sum of // dimensions in the input shape @@ -644,7 +644,7 @@ bool ReverseReshapeRel(const Array& types, int num_inputs, const Attrs& at return false; } - const auto& oshape = infer_newshape(data->shape, attrs, true); + const auto& oshape = InferNewShape(data->shape, attrs, true); // Verify that the sum of dimensions in the output shape is the sum of // dimensions in the input shape @@ -743,7 +743,7 @@ Array ReshapeCompute(const Attrs& attrs, const Array& in } if (newshape_has_any) { - newshape = infer_newshape(inputs[0]->shape, attrs, false); + newshape = InferNewShape(inputs[0]->shape, attrs, false); } return {topi::reshape(inputs[0], newshape)}; } diff --git a/src/relay/op/tensor/transform.h b/src/relay/op/tensor/transform.h index 34aaf4689a59..a3770ff9cd8d 100644 --- a/src/relay/op/tensor/transform.h +++ b/src/relay/op/tensor/transform.h @@ -195,7 +195,7 @@ static inline Array> ConcatenateLayout(const Attrs& attrs, * \param attrs The attributes. * \return Output shape. */ -Array infer_newshape(const Array& data_shape, const Attrs& attrs); +Array InferNewShape(const Array& data_shape, const Attrs& attrs); } // namespace relay } // namespace tvm