[Unity][Transform] Introduce data-dependent operation of reshape and its constant folding #14282
Changes from all commits: acbaa1f, 4809080, 1a7ae44, 302e525, 36cc64a, bdfaa6b, 8ce3895, eb40129, f4095b6, 600a94d, 6022718, 28c71b4
First changed file (the transform pass):

```diff
@@ -21,6 +21,7 @@
+#include <tvm/relax/analysis.h>
 #include <tvm/relax/attrs/nn.h>
 #include <tvm/relax/struct_info.h>
 #include <tvm/relax/transform.h>

 #include "utils.h"

@@ -110,21 +111,63 @@ class NormInferenceSimplifier : public ExprMutator {
   Map<Expr, Expr> batch_norm_map_;
 };

+class OpDecomposer : public ExprMutator {
```
Member: The name of OpDecomposer is too generic since it's only for ...

Contributor (Author): We can extend this function to add other composite ops like ...
```diff
+ public:
+  static Expr Decompose(Expr expr) { return OpDecomposer()(expr); }
+
+ private:
+  using ExprMutator::VisitExpr_;
+
+  Expr TensorToShape(const Call& call_node) {
+    ICHECK(call_node->struct_info_.defined());
+    Expr expr = call_node->args[0];
+    const ShapeStructInfoNode* sinfo = GetStructInfoAs<ShapeStructInfoNode>(call_node);
+    ICHECK(sinfo);
+    // call builtin function that converts tensor to shape tuple
+    // TODO(@sunggg): Register operator for "vm.builtin.tensor_to_shape"
+    Var call = builder_->Emit(Call(ExternFunc("vm.builtin.tensor_to_shape"), {expr}, {},
+                                   {GetRef<ShapeStructInfo>(sinfo)}));
+
+    // Operators like reshape take the output of `TensorToShape` as their output shape.
+    // Because TOPI expects to have such output shape in symbolic shape at least (i.e.,
+    // Array<PrimExpr>), we define symbolic variables and return them as a ShapeExpr.
+    Array<PrimExpr> shape_var;
+    for (int i = 0; i < sinfo->ndim; i++) {
+      shape_var.push_back(tir::Var("x", DataType::Int(64)));
+    }
+    // bind symbolic variables to the shape tuple
+    relax::Var var("y", ShapeStructInfo(shape_var));
+    builder_->EmitNormalized(MatchCast(var, call, ShapeStructInfo(shape_var)));
+    return ShapeExpr(shape_var);
+  }
+
+  Expr VisitExpr_(const CallNode* call_node) final {
+    Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
+    if (call->op == tensor_to_shape_op_) {
+      return TensorToShape(call);
+    } else {
+      return call;
+    }
+  }
+
+  const Op& tensor_to_shape_op_ = Op::Get("relax.tensor_to_shape");
+};
+
 namespace transform {

-Pass SimplifyNormInference() {
+Pass DecomposeCompositeOps() {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
       [=](Function f, IRModule m, PassContext pc) {
         f = Downcast<Function>(NormInferenceSimplifier::Simplify(f));
-        // Remove original batch_norm op if it's not used.
+        f = Downcast<Function>(OpDecomposer::Decompose(f));
+        // Remove original ops if it's not used.
         return RemoveAllUnused(f);
       };
   return CreateFunctionPass(/*pass_function=*/pass_func,  //
                             /*opt_level=*/0,              //
-                            /*pass_name=*/"SimplifyNormInference",  //
+                            /*pass_name=*/"DecomposeCompositeOps",  //
                             /*required=*/{});
}

-TVM_REGISTER_GLOBAL("relax.transform.SimplifyNormInference").set_body_typed(SimplifyNormInference);
+TVM_REGISTER_GLOBAL("relax.transform.DecomposeCompositeOps").set_body_typed(DecomposeCompositeOps);

 }  // namespace transform
 }  // namespace relax
```
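As a point of reference (not part of the diff), here is a minimal sketch of how the renamed pass could be driven from C++. The helper name `DecomposeTensorToShape` and the input module `mod` are placeholders, and the trailing comments summarize the rewrite performed by `TensorToShape` above.

```cpp
#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>

// Hypothetical helper: fetch the pass through its registered global and run it
// on a module that contains `relax.tensor_to_shape` calls.
tvm::IRModule DecomposeTensorToShape(tvm::IRModule mod) {
  const tvm::runtime::PackedFunc* make_pass =
      tvm::runtime::Registry::Get("relax.transform.DecomposeCompositeOps");
  ICHECK(make_pass != nullptr) << "Relax transforms are not built into this TVM";
  tvm::transform::Pass pass = (*make_pass)();
  // After the pass, each `relax.tensor_to_shape(t)` binding becomes roughly:
  //   call = vm.builtin.tensor_to_shape(t)            (extern call, ShapeStructInfo)
  //   y    = match_cast(call, Shape([x0, ..., xk]))   (binds fresh symbolic vars)
  //   ... and consumers such as reshape receive ShapeExpr([x0, ..., xk]).
  return pass(mod);
}
```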
Second changed file (the VM runtime builtins):

```diff
@@ -380,6 +380,40 @@ TVM_REGISTER_GLOBAL("vm.builtin.make_tuple").set_body([](TVMArgs args, TVMRetVal
   *rv = arr;
 });

+TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data) {
+  NDArray arr = data;
+  if (data->device.device_type != kDLCPU) {
+    arr = data.CopyTo(DLDevice{kDLCPU, 0});
```
Contributor: Does this copy fail if the NDArray is on another device? I'm a little hesitant to have an op that just will not work depending on the device.

Member: Is it possible that we have a shape tensor not on the host device?

Contributor: Not sure where this might come up. I wouldn't be surprised if, say, we import an ONNX model (this is where this use case first came up) that we mean to run on GPU and every tensor (including those that stand for shapes) is stored on GPU.

Contributor (Author): Based on my current understanding, we want to keep the shape-related computation on the host side since its result could be related to the memory planning. Even for general GPU execution besides this case, we are running output shape computation (inserted by ...
```diff
+  }
+
+  ICHECK_EQ(arr->ndim, 1);
+  ICHECK_EQ(arr->dtype.code, kDLInt);
+
+  std::vector<int64_t> out_shape;
+  for (int i = 0; i < arr.Shape()[0]; ++i) {
+    int64_t result;
+    switch (arr->dtype.bits) {
+      case 16: {
+        result = reinterpret_cast<int16_t*>(arr->data)[i];
+        break;
+      }
+      case 32: {
+        result = reinterpret_cast<int32_t*>(arr->data)[i];
+        break;
+      }
+      case 64: {
+        result = reinterpret_cast<int64_t*>(arr->data)[i];
+        break;
+      }
+      default:
+        LOG(FATAL) << "Unknown scalar int type: " << DLDataType2String(arr->dtype);
+        throw;
+    }
+    out_shape.push_back(result);
+  }
+  return ShapeTuple(out_shape);
+});

 }  // namespace relax_vm
 }  // namespace runtime
 }  // namespace tvm
```
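For completeness, a small usage sketch (again not from the PR) that exercises the new builtin directly through the registry, assuming a 1-D int64 tensor that already lives on the CPU:

```cpp
#include <tvm/runtime/logging.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/registry.h>

// Hypothetical test-style usage of the builtin registered above.
void TensorToShapeExample() {
  using namespace tvm::runtime;
  // A 1-D int64 tensor holding the target shape (2, 224, 224), already on CPU.
  NDArray data = NDArray::Empty(ShapeTuple({3}), DLDataType{kDLInt, 64, 1},
                                DLDevice{kDLCPU, 0});
  int64_t* ptr = static_cast<int64_t*>(data->data);
  ptr[0] = 2;
  ptr[1] = 224;
  ptr[2] = 224;

  const PackedFunc* f = Registry::Get("vm.builtin.tensor_to_shape");
  ICHECK(f != nullptr) << "Relax VM runtime is not built into this TVM";
  ShapeTuple shape = (*f)(data);  // -> ShapeTuple(2, 224, 224)
  ICHECK_EQ(shape.size(), 3);
  ICHECK_EQ(shape[0], 2);
}
```

The switch on `arr->dtype.bits` in the builtin means 16-, 32-, and 64-bit integer shape tensors are all accepted.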
Reviewer: I think this condition is more restrictive than necessary. I think it's okay for the shape not to be known at compile time and to return `ShapeStructInfo(ndim=-1)` if the input rank/shape is unknown at compile time.

Edit: I guess you use the rank in `vm_builtin_lower.cc`, but I don't see why it has to be implemented that way. You could check all the necessary properties dynamically (inside the builtin).

Author: Do you mean these lines? Initially, I also wanted to support the unknown rank but realized it is trickier than I thought. The problem is that you need to insert these symbolic variables at compile time, so we need this info.

Reviewer: What exactly are the new shape vars needed for?

Author: To MatchCast and put them in the `ShapeExpr`.

Reviewer: That much I understand; what I am more curious about is why we need to construct this shape expression.

Reviewer: Why does lowering require the dimensions to be bound to a variable? That doesn't seem like a restriction that needs to exist.

Comment: @slyubomirsky To pre-allocate memory for TIR functions? TE compute takes the output shape in `Array<PrimExpr>` as the first argument:
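The snippet originally referenced after that colon is not shown here; as a rough sketch of the point being made, `te::compute` takes the output shape as an `Array<PrimExpr>` first argument, so the symbolic vars bound by the `MatchCast` are what a lowered reshape would receive. The helper name `ReshapeLike` and the flat 1-D input are assumptions of the sketch.

```cpp
#include <tvm/te/operation.h>

// Sketch: lowering a data-dependent reshape of a flat 1-D buffer into an
// (m, n) output, where m and n are the symbolic vars produced by MatchCast.
tvm::te::Tensor ReshapeLike(tvm::te::Tensor data, tvm::tir::Var m, tvm::tir::Var n) {
  using namespace tvm;
  // te::compute needs the output shape as Array<PrimExpr> up front, which is
  // why the pass materializes a ShapeExpr of symbolic vars at compile time.
  return te::compute(
      {m, n},
      [&](const Array<tir::Var>& i) { return data(i[0] * n + i[1]); },
      "reshape");
}
```

This is the TE-side constraint the thread refers to: the output extents must exist as `PrimExpr`s at lowering time, even though their values only become known when `vm.builtin.tensor_to_shape` runs.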
Member: Great discussion, folks! Looks like the limitation for the shape_expr is from TE, interesting.

Reviewer: Hmmmm, I think we should put this on the list to fix later. Let's please note this limitation in a comment, since it is not obvious from looking at the code.

Author: Added this in the comment. Would it be good to go now?