Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,9 @@ struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
/*! \brief Attributes used in reshape operators */
struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
Array<Integer> newshape;
bool reverse;
TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
TVM_ATTR_FIELD(newshape).describe(
"The new shape. Should be compatible with the original shape.");
TVM_ATTR_FIELD(reverse)
.describe("Infer the special values from right to left if true")
.set_default(false);
}
}; // struct ReshapeAttrs

Expand Down
1 change: 0 additions & 1 deletion src/relay/op/dyn/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& in

Expr MakeReshape(Expr data, Expr newshape) {
auto attrs = make_object<ReshapeAttrs>();
attrs->reverse = false;
static const Op& op = Op::Get("dyn.reshape");
return Call(op, {data, newshape}, Attrs(attrs), {});
}
Expand Down
76 changes: 58 additions & 18 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -455,13 +455,14 @@ RELAY_REGISTER_OP("transpose")
TVM_REGISTER_NODE_TYPE(ReshapeAttrs);
TVM_REGISTER_NODE_TYPE(ReshapeLikeAttrs);

Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape, const Attrs& attrs) {
Array<IndexExpr> InferNewShape(const Array<IndexExpr>& data_shape, const Attrs& attrs,
bool reverse) {
const auto* param = attrs.as<ReshapeAttrs>();
Array<IndexExpr> oshape;
Array<IndexExpr> ishape;
Array<Integer> newshape;

if (param->reverse) {
if (reverse) {
ishape.Assign(data_shape.rbegin(), data_shape.rend());
newshape.Assign(param->newshape.rbegin(), param->newshape.rend());
} else {
Expand Down Expand Up @@ -584,7 +585,6 @@ Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape, const Attrs&

bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
const auto* param = attrs.as<ReshapeAttrs>();
// types: [data, result]
ICHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
Expand All @@ -594,16 +594,12 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return false;
}

const auto& oshape = infer_newshape(data->shape, attrs);
const auto& oshape = InferNewShape(data->shape, attrs, false);

// Verify that the sum of dimensions in the output shape is the sum of
// dimensions in the input shape
Array<IndexExpr> data_shape;
if (param->reverse) {
data_shape.Assign(data->shape.rbegin(), data->shape.rend());
} else {
data_shape = data->shape;
}
data_shape = data->shape;

bool found_dynamic = false;
int64_t oshape_sum = 1;
Expand Down Expand Up @@ -633,12 +629,58 @@ bool ReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
<< "Input tensor shape and reshaped shape are not compatible";
}

if (param->reverse) {
reporter->Assign(types[1],
TensorType(Array<IndexExpr>(oshape.rbegin(), oshape.rend()), data->dtype));
} else {
reporter->Assign(types[1], TensorType(oshape, data->dtype));
reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}

bool ReverseReshapeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// types: [data, result]
ICHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
ICHECK(types[0].as<IncompleteTypeNode>())
<< "reshape: expect input type to be TensorType but get " << types[0];
return false;
}

const auto& oshape = InferNewShape(data->shape, attrs, true);

// Verify that the sum of dimensions in the output shape is the sum of
// dimensions in the input shape
Array<IndexExpr> data_shape;
data_shape.Assign(data->shape.rbegin(), data->shape.rend());

bool found_dynamic = false;
int64_t oshape_sum = 1;
for (auto& x : oshape) {
// Check if we have a dynamic shape. If we do, we can't verify if the
// reshape is valid. Dynamic shapes are marker by using Any, but can also
// occur from SizeVar's. In the case of SizeVar, the shape expression can
// be an AST. We can't easily check if we have an AST because of a ShapeVar
// or some other reason, so our check for dynamic shape is just if we can
// convert the shape to in integer or not.
if (!x->IsInstance<tvm::Integer::ContainerType>()) {
found_dynamic = true;
break;
}
oshape_sum *= Downcast<tvm::Integer>(x)->value;
}
int64_t data_shape_sum = 1;
for (auto& x : data_shape) {
if (!x->IsInstance<tvm::Integer::ContainerType>()) {
found_dynamic = true;
break;
}
data_shape_sum *= Downcast<tvm::Integer>(x)->value;
}
if (!found_dynamic) {
ICHECK_EQ(oshape_sum, data_shape_sum)
<< "Input tensor shape and reshaped shape are not compatible";
}

reporter->Assign(types[1],
TensorType(Array<IndexExpr>(oshape.rbegin(), oshape.rend()), data->dtype));
return true;
}

Expand Down Expand Up @@ -701,15 +743,14 @@ Array<te::Tensor> ReshapeCompute(const Attrs& attrs, const Array<te::Tensor>& in
}

if (newshape_has_any) {
newshape = infer_newshape(inputs[0]->shape, attrs);
newshape = InferNewShape(inputs[0]->shape, attrs, false);
}
return {topi::reshape(inputs[0], newshape)};
}

Expr MakeReshape(Expr data, Array<Integer> newshape) {
auto attrs = make_object<ReshapeAttrs>();
attrs->newshape = std::move(newshape);
attrs->reverse = false;
static const Op& op = Op::Get("reshape");
return Call(op, {data}, Attrs(attrs), {});
}
Expand Down Expand Up @@ -2867,7 +2908,6 @@ RELAY_REGISTER_OP("auto_scheduler_layout_transform")
Expr MakeReverseReshape(Expr data, Array<Integer> newshape) {
auto attrs = make_object<ReshapeAttrs>();
attrs->newshape = std::move(newshape);
attrs->reverse = true;
static const Op& op = Op::Get("contrib_reverse_reshape");
return Call(op, {data}, Attrs(attrs), {});
}
Expand All @@ -2892,7 +2932,7 @@ example below::
.set_attrs_type<ReshapeAttrs>()
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(10)
.add_type_rel("Reshape", ReshapeRel)
.add_type_rel("ReverseReshape", ReverseReshapeRel)
.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);

Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/tensor/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ static inline Array<Array<Layout>> ConcatenateLayout(const Attrs& attrs,
* \param attrs The attributes.
* \return Output shape.
*/
Array<IndexExpr> infer_newshape(const Array<IndexExpr>& data_shape, const Attrs& attrs);
Array<IndexExpr> InferNewShape(const Array<IndexExpr>& data_shape, const Attrs& attrs);

} // namespace relay
} // namespace tvm
Expand Down
1 change: 0 additions & 1 deletion tests/python/contrib/test_arm_compute_lib/test_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def _get_expected_codegen(input_shape, output_shape, dtype):
"newshape": [[str(s) for s in output_shape]],
"shape": [[list(output_shape)]],
"dtype": [[dtype]],
"reverse": [["0"]],
},
}

Expand Down