
Conversation

@sunggg (Contributor) commented Mar 12, 2023

As discussed in the previous Unity open dev meeting, this PR implements the data-dependent operation of reshape.

Representative examples

This PR realizes the data-dependent operation of reshape as follows.

# Target shape will be computed at runtime and stored in tensor
target_shape: R.Tensor(ndim=1) = ...
# With the new `tensor_to_shape`, we can convert tensor values to a ShapeExpr
lv: R.Shape(ndim=2) = R.tensor_to_shape(target_shape)
# Reshape is extended to take a Var as long as it is bound to a ShapeExpr
gv: R.Tensor(ndim=2, dtype="float32") = R.reshape(data, lv)

Also, FoldConstant is extended to support tensor_to_shape.

# Before `FoldConstant`: c0, c1 are constants
lv0 = R.add(c0, c0)
target_shape = R.multiply(lv0, c1)
lv2: R.Shape(ndim=2) = R.tensor_to_shape(target_shape)
gv: R.Tensor(ndim=2, dtype="float32") = R.reshape(data, lv2)

# After `FoldConstant`
gv: R.Tensor((16, 16), dtype="float32") = R.reshape(data, R.shape([16, 16]))

Summary of changes

  • Introduce a new builtin, tensor_to_shape
  • Currently, reshape takes the target shape as ShapeExpr | Array[PrimExpr]. This PR extends it to also accept a Var, as long as the Var is bound to a ShapeExpr (see the sketch after this list).
  • Extend the FoldConstant pass to support tensor_to_shape
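
For illustration (a hedged sketch reusing the variables from the example above), both forms are accepted after this change:

# target shape given directly as a ShapeExpr, as before
gv1: R.Tensor((16, 16), dtype="float32") = R.reshape(data, R.shape([16, 16]))

# target shape given as a Var, allowed because lv is bound to a ShapeExpr
lv: R.Shape(ndim=2) = R.tensor_to_shape(target_shape)
gv2: R.Tensor(ndim=2, dtype="float32") = R.reshape(data, lv)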

03/18/2023 Update

It turned out that if we implement tensor_to_shape as a builtin, its lowering happens too late: we cannot lower the following reshape-like ops, since they need at least a symbolic shape to be legalized. Therefore, tensor_to_shape should be lowered before we lower reshape. The new change extends the existing SimplifyNormInference pass to serve as a generic pass that decomposes composite operators such as tensor_to_shape, attention, erf, etc.
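
As a rough illustration (a sketch of the idea, not the exact IR the pass emits; the symbolic names x0, x1 are made up), the decomposition replaces the call with a MatchCast that binds fresh symbolic dimensions and rebuilds the shape from them:

# before decomposition
lv: R.Shape(ndim=2) = R.tensor_to_shape(target_shape)

# after decomposition: x0, x1 are fresh symbolic variables introduced by the pass
y: R.Shape([x0, x1]) = R.match_cast(R.tensor_to_shape(target_shape), R.Shape([x0, x1]))
lv: R.Shape([x0, x1]) = R.shape([x0, x1])

With the target shape available as a ShapeExpr with symbolic dims, the later reshape legalization can proceed as usual.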

Follow-up PRs

  • Formal introduction of composite ops
  • Refactor the FoldConstant pass to handle such composite ops

@tvm-bot (Collaborator) commented Mar 12, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@sunggg (Contributor Author) commented Mar 12, 2023

@tqchen (Member) commented Mar 13, 2023

Please update via

git rebase --onto upstream/unity upstream/unity-rebase-backup-2023-03-13


def test_op_tensor_to_shape():
    out_shape = run_cpu(
        TensorToShapeTest, "run_tensor_to_shape", tvm.nd.array(np.array([1, 2, 3]).astype("int64"))
    )
@slyubomirsky (Contributor) commented Mar 13, 2023

I think the demonstrated behavior is correct. We are converting the rank-1 tensor [1, 2, 3] into a shape (1, 2, 3). I think the problem here is that ndim means different things for tensors (where it means the rank) and shapes (where it means the number of dimensions denoted by the shape).
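
For illustration (a hedged sketch; the variable names are made up), the same conversion reports different ndim values for the input tensor and the output shape:

# the input is a rank-1 tensor holding three values, so its ndim is 1
t: R.Tensor((3,), dtype="int64") = ...  # values [1, 2, 3]
# the result denotes three dimensions, so its ndim is 3
s: R.Shape(ndim=3) = R.tensor_to_shape(t)  # denotes the shape (1, 2, 3)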

@sunggg force-pushed the data_dep_reshape branch from 47eead9 to b75f11f on March 13, 2023 20:24
ICHECK(call->args.size() == 1);
ICHECK(call->args[0]->struct_info_.defined());
const auto* tsinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
ICHECK(tsinfo && tsinfo->shape.defined());
@slyubomirsky (Contributor) commented Mar 13, 2023

I think this condition is more restrictive than necessary. I think it's okay for the shape not to be known at compile time and to return ShapeStructInfo(ndim=-1) if the input rank/shape is unknown at compile time.

Edit: I guess you use the rank in vm_builtin_lower.cc, but I don't see why it has to be implemented that way. You could check all the necessary properties dynamically (inside the builtin).

@sunggg (Contributor Author) replied:

Do you mean these lines?

  // define symbolic variables
  Array<PrimExpr> shape_var;
  for (int i = 0; i < sinfo->ndim; i++) {
    shape_var.push_back(tir::Var("x", DataType::Int(64)));
  }

Initially, I also wanted to support unknown rank but realized it is trickier than I thought.
The problem is that you need to insert these symbolic variables at compile time, so we need this info.

Contributor replied:

What exactly are the new shape vars needed for?

@sunggg (Contributor Author) replied:

To MatchCast and put them in the ShapeExpr.

// define symbolic variables
Array<PrimExpr> shape_var;
for (int i = 0; i < sinfo->ndim; i++) {
    shape_var.push_back(tir::Var("x", DataType::Int(64)));
}

// bind symbolic variables to the shape tuple
relax::Var var("y", ShapeStructInfo(shape_var));
builder_->EmitNormalized(MatchCast(var, call, ShapeStructInfo(shape_var)));
return ShapeExpr(shape_var);

Contributor replied:

That much I understand; what I am more curious about is why we need to construct this shape expression.

Contributor replied:

Why does lowering require the dimensions to be bound to a variable? That doesn't seem like a restriction that needs to exist.

@sunggg (Contributor Author) replied:

@slyubomirsky To pre-allocate memory for TIR functions? TE compute takes the output shape, an Array<PrimExpr>, as its first argument:

Tensor compute(Array<PrimExpr> shape, FCompute fcompute, std::string name, std::string tag,
               Map<String, ObjectRef> attrs)
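
For context, here is a minimal Python-side sketch of the same constraint (the names x0, x1 and the trivial compute body are made up for illustration): te.compute needs the output extents as PrimExprs up front, which is why the dynamic dims are first bound to symbolic variables.

import tvm
from tvm import te

# symbolic extents standing in for the dims recovered from the shape tensor
x0 = te.var("x0", dtype="int64")
x1 = te.var("x1", dtype="int64")

# the output shape must be a list of PrimExprs; an opaque runtime tensor would not work here
out = te.compute((x0, x1), lambda i, j: tvm.tir.const(0.0, "float32"), name="out")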

Member replied:

Great discussion, folks! Looks like the limitation for shape_expr comes from TE, interesting.

Contributor replied:

Hmmmm, I think we should put this on the list to fix later. Please note this limitation in a comment, since it is not obvious from looking at the code.

@sunggg (Contributor Author) replied:

Added this in the comment. Would it be good to go now?

TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data) {
  NDArray arr = data;
  // Reading the shape values requires host access, so copy to CPU if needed.
  if (data->device.device_type != kDLCPU) {
    arr = data.CopyTo(DLDevice{kDLCPU, 0});
Contributor commented:

Does this copy fail if the NDArray is on another device? I'm a little hesitant to have an op that just will not work depending on the device.

Member replied:

Is it possible that we have a shape tensor not on host device?

Contributor replied:

Not sure where this might come up. I wouldn't be surprised if, say, we import an ONNX model (this is where this use-case first came up) that we mean to run on GPU and every tensor (including those that stand for shapes) is stored on GPU.

@sunggg (Contributor Author) commented Mar 15, 2023

Based on my current understanding, we want to keep the shape-related computation on the host side since its result could be related to the memory planning. Even for general GPU execution besides this case, we are running output shape computation (inserted by VMShapeLower pass) on the host side.

@sunggg force-pushed the data_dep_reshape branch from 11842e3 to 802a01b on March 19, 2023 00:37
@sunggg changed the title from "[Unity][BuiltinOp][Transform] Introduce data-dependent operation of reshape and its constant folding" to "[Unity][Transform] Introduce data-dependent operation of reshape and its constant folding" on Mar 19, 2023
@sunggg (Contributor Author) commented Mar 19, 2023

This PR extends the existing SimplifyNormInference pass to serve as a generic pass that decomposes composite operators like tensor_to_shape, attention, erf, etc. Would you check if you are okay with this change?
cc @Hzfengsy @SiriusNEO

@Hzfengsy (Member) left a comment:
LGTM except for a naming issue

Map<Expr, Expr> batch_norm_map_;
};

class OpDecomposer : public ExprMutator {
Member replied:
The name OpDecomposer is too generic since it's only for tensor_to_shape.

@sunggg (Contributor Author) replied:

We can extend this to add other composite ops like erf and attention, which are coming soon.

@SiriusNEO (Contributor) commented:

@sunggg Hi, actually we also made some changes in SimplifyNormInference, but they have not been upstreamed yet (mlc-ai/relax#162). I notice that you don't change the part which simplifies BatchNorm; I can find some time to rebase my changes on yours.

@sunggg force-pushed the data_dep_reshape branch from 3de7d67 to 24a9fe7 on March 20, 2023 02:56
@sunggg (Contributor Author) commented Mar 20, 2023

@yongwww @slyubomirsky would you take another look?

@yongwww (Member) commented Mar 20, 2023

@sunggg it would be good to rebase it via git rebase --onto upstream/unity upstream/unity-rebase-backup-2023-03-20

@sunggg force-pushed the data_dep_reshape branch from a07929a to f4095b6 on March 20, 2023 20:58
@yongwww (Member) left a comment:
LGTM, thanks for addressing my concerns!

@slyubomirsky (Contributor) left a comment:
My concerns have been addressed, thank you

@zxybazh merged commit 675a22e into apache:unity on Mar 21, 2023
@zxybazh (Member) commented Mar 21, 2023

Thanks @sunggg for the thoughtful design and quick implementation!

tqchen pushed a commit that referenced this pull request Apr 1, 2023
…its constant folding (#14282)

* FEAT: Support data-dependent operation of reshape

* FEAT: Support constant folding with data-dependent reshape

* fix

* remove empty line

* reflect feedback

* Lift the lowering of tensor_to_shape from builtin to DecomposeCompositeOps pass

* fix and comment

* fix

* add comments

* reflect feedback

* add comment

* fix
tqchen pushed a commit to mlc-ai/relax that referenced this pull request Apr 5, 2023
PR (apache/tvm#14282) refactors the pass `SimplifyNorm` to `DecomposeOps`. The last rebase left conflicts to be fixed, and this PR merges apache/tvm#14282 and #162 together.

The changes mainly include:
- func pass -> module pass (because sometimes we don't want to simplify all functions in a module)
- Add a `mode` argument to indicate whether it is a training simplification or an eval simplification.