@sunggg (Contributor) commented Mar 31, 2023

In Unity, we have a clear distinction between tensors and shapes: ShapeExpr and ShapeStructInfo in the AST, and ShapeTuple as the runtime container. Meanwhile, most operators and their TOPI implementations are defined over tensors. For example, relax.take is defined as follows:

// Operator definition
TVM_REGISTER_OP("relax.take")
    .set_attrs_type<TakeAttrs>()
    .set_num_inputs(2)
    .add_argument("x", "Tensor", "The source tensor.")
    .add_argument("indices", "Tensor", "The indices of the values to extract.")
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTake);
       
StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) {
  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
  // Assume the inputs are `Tensor`
  TensorStructInfo data_sinfo = input_sinfo[0];
  TensorStructInfo indices_sinfo = input_sinfo[1];
  ...
}

// TOPI
inline Tensor take(const Tensor& a, const Tensor& indices, int batch_dims, int axis, ...)

To allow shape computation with these operators, this PR introduces a shape_to_tensor op that converts a ShapeTuple to an NDArray at runtime. This enables common shape computation patterns like the following:

shape_var: R.Shape(ndim=4)
lv: R.Tensor((4,), dtype="int64") = R.shape_to_tensor(shape_var)
lv1: R.Tensor((1,), dtype="int64") = R.take(lv, indices, axis=0)
lv2: R.Tensor((1, 1), dtype="int64") = R.expand_dims(lv1, axis=[0])
gv: R.Tensor((1, 1), dtype="int64") = R.concat((lv2,), axis=0)
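
For illustration, the pattern above can be packaged as a complete Relax function. The following is a minimal TVMScript sketch; the module name, parameter names, and the (1,) indices tensor are assumptions for the example:

from tvm.script import ir as I, relax as R

@I.ir_module
class ExtractDims:
    @R.function
    def main(
        shape_var: R.Shape(ndim=4), indices: R.Tensor((1,), "int64")
    ) -> R.Tensor((1, 1), "int64"):
        # Materialize the symbolic shape as a 1-D int64 tensor on the host.
        lv = R.shape_to_tensor(shape_var)
        # From here on, ordinary tensor ops can operate on the shape values.
        lv1 = R.take(lv, indices, axis=0)
        lv2 = R.expand_dims(lv1, axis=[0])
        gv = R.concat((lv2,), axis=0)
        return gv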

It's worth noting that the tensor_to_shape op was already introduced in #14282, so a round trip between shape and tensor is now possible.
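
A minimal round-trip sketch in the same TVMScript style (the module name is hypothetical; R.tensor_to_shape comes from #14282):

from tvm.script import ir as I, relax as R

@I.ir_module
class RoundTrip:
    @R.function
    def main(s: R.Shape(ndim=4)) -> R.Shape(ndim=4):
        # Shape -> Tensor: materialize the ShapeTuple as a 1-D int64 tensor.
        t = R.shape_to_tensor(s)
        # Tensor -> Shape: recover a ShapeTuple from the tensor at runtime.
        s2 = R.tensor_to_shape(t)
        return s2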

Currently, this op requires special handling in the FoldConstant pass, since that pass can only evaluate TIR PrimFuncs, not PackedFuncs. Once FoldConstant is extended to support PackedFunc evaluation, this special handling can be removed.
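
For context, a minimal sketch of where the pass is applied; the module here is a stand-in containing only the new op:

from tvm import relax
from tvm.script import ir as I, relax as R

@I.ir_module
class Mod:
    @R.function
    def main(s: R.Shape(ndim=4)) -> R.Tensor((4,), "int64"):
        t = R.shape_to_tensor(s)
        return t

# FoldConstant evaluates calls backed by TIR PrimFuncs at compile time;
# shape_to_tensor lowers to a PackedFunc call, hence the special case in the pass.
folded = relax.transform.FoldConstant()(Mod)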

cc. @jwfromm @yongwww @psrivas2 @slyubomirsky @tqchen

@tvm-bot (Collaborator) commented Mar 31, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

ctx->ReportFatal(Diagnostic::Error(call)
                 << op << " requires the input " << op->arguments[i]->name
-                << " to be Tensor. However, the given one is "
+                << " to be Tensor. However, the given one has a "
@sunggg (Contributor, Author) commented on the error message above:

This is unrelated to this PR, but I found the current message confusing while debugging.

    shape_tuple: tvm.runtime.ShapeTuple
        Shape tuple that we want to convert to NDArray at runtime
    """
    return tvm.nd.array([int(v) for v in shape_tuple])
A member commented:

Do we assume it's always on CPU?

@sunggg (Contributor, Author) replied:

Yes, the shape tuple and shape computation happen on the CPU side.
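
For illustration, the same conversion can be exercised directly at the Python runtime level; a minimal sketch with a hypothetical shape:

import tvm

st = tvm.runtime.ShapeTuple([1, 3, 224, 224])  # hypothetical shape
arr = tvm.nd.array([int(v) for v in st])       # the conversion shown above
print(arr.device)                              # cpu(0): shape math stays on the host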

@tqchen force-pushed the unity branch 2 times, most recently from a425bc7 to 5c8b7af on April 1, 2023
@sunggg force-pushed the op_shape_to_tensor branch from 9870c35 to fc0cc25 on April 2, 2023
@yongwww (Member) left a comment:

LGTM!

@vinx13 merged commit db01567 into apache:unity on Apr 3, 2023
