-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[Unity][Transform] Common Subexpression Elimination #14361
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
| // 3. Scalar constants (not much benefit from binding to a var) | ||
| if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() || | ||
| e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() || | ||
| (e.as<ConstantNode>() && (e.as<ConstantNode>()->is_scalar())))) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's up to us if we think it's worth rebinding scalar constants. I doubt it, though.
| Expr VisitExpr_(const FunctionNode* func) override { | ||
| // for an inner function, we will do CSE on its body | ||
| Expr new_body = ExprMutator::VisitExpr(func->body); | ||
| if (new_body.same_as(func->body)) { | ||
| return GetRef<Expr>(func); | ||
| } | ||
| return Function(func->params, new_body, func->ret_struct_info, func->attrs, func->span); | ||
| } | ||
|
|
||
| // this should happen only for the inner function case | ||
| Expr VisitExpr_(const SeqExprNode* seq) override { | ||
| bool all_unchanged = true; | ||
| Array<BindingBlock> new_blocks; | ||
| // apply CSE within dataflow blocks only | ||
| for (auto block : seq->blocks) { | ||
| if (const DataflowBlockNode* df_block = block.as<DataflowBlockNode>()) { | ||
| auto new_df_block = EliminateCommonSubexpr(GetRef<DataflowBlock>(df_block)); | ||
| if (!new_df_block.same_as(block)) { | ||
| new_blocks.push_back(new_df_block); | ||
| all_unchanged = false; | ||
| continue; | ||
| } | ||
| } | ||
| new_blocks.push_back(block); | ||
| } | ||
|
|
||
| if (all_unchanged) { | ||
| return GetRef<Expr>(seq); | ||
| } | ||
| // do not visit the body | ||
| return SeqExpr(new_blocks, seq->body, seq->span); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seemed like a bit of a strange thing to have to implement in a dataflow block pass. It could be avoided if we require lambda-lifting (arguably we should for all DF block passes).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Update: Based on the Unity Community Meeting discussion, it doesn't sound like there is much appetite for imposing phase orderings like this, so I would be interested instead if there is a clean way to deal with local functions in dataflow block passes (generalizing the approach shown here, for example). I am sure that other dataflow block passes don't handle the local function case and might exhibit strange bugs if given a program with local functions
e: Possible solution: Change DataflowBlockMutator to look for inner functions and process them separately. This way the pass_func for dataflow block passes can safely just ignore inner functions and it would be handled by the pass infrastructure
| lv0 = R.add(x, y) | ||
| # can combine with canonicalizing bindings | ||
| # and getting rid of unused bindings to eliminate this line too | ||
| lv1 = lv0 | ||
| gv = R.multiply(lv0, lv1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't want to reimplement the functionality of the canonicalize bindings pass, but we could fold that in here if we want to (I don't think it's a good idea).
| # can further clean this up | ||
| # using canonicalize bindings, eliminate unused bindings, and CSE again | ||
| lv0 = bar(x) | ||
| lv1 = lv0 | ||
| lv2 = R.add(lv0, lv1) | ||
| lv3 = lv0 | ||
| lv4 = lv0 | ||
| lv5 = R.add(lv3, lv4) | ||
| lv6 = R.add(lv2, lv5) | ||
| gv = lv6 | ||
| R.output(gv) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not ideal because there are still repeated additions. @psrivas2 once proposed having a canonicalization pass that runs until fixpoint. This is an example where running such a pass will help.
|
cc @masahi |
|
Nice work! In what scenarios do we want to apply this pass? |
|
The ideal scenario would be if some model (especially an imported one) repeatedly invokes the same operation (an expensive one). We encountered that pattern already in some imported models. CSE is a classic optimization. |
Yeah I agree. One issue that comes to my mind is that it might be critical to decide what subexprs to eliminate. If it's lightweight and inlinable to surrounding ops, then we should probably reject to eliminate the redundancy. In classical settings, the trade-off of CSE is to enlarge the live range of some registers, which might cause performance regression due to register spill. But for DL workloads, it's likely that the model simply can not be deployed given a certain amount of GPU Memory, which is more severe than running slower. Of course the mem planner can try smarter strategies similar with live range split, by recomputing some calculations. |
|
In this case, the bindings will be live only within a single dataflow block, so I don't think there will be many issues with keeping values live for much longer than they would be otherwise. It would be easy to add a heuristic for deciding when we shouldn't deduplicate. edit: Playing around with recomputation to reduce memory requirements is definitely a strategy we can try. |
psrivas2
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks for handling the inner func case here too!
Going forward, we should definitely run lambda lifting pass ahead of time to avoid this burden of dataflow pass writers or modify the DataflowBlockMutator to handle inner functions.
yongwww
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me!
| with R.dataflow(): | ||
| # we are going to do CSE inside the local function | ||
| @R.function | ||
| def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure if we should handle the case: duplicate gloabl/local functions. I met this case when lower jax to hlo.
@R.function
def bar1(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
z = R.add(R.add(y, y), R.add(y, y))
with R.dataflow():
lv0 = R.add(y, y)
lv1 = R.add(y, y)
lv2 = R.add(lv0, lv1)
gv = lv2
R.output(gv)
return R.add(z, gv)
@R.function
def bar2(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
z = R.add(R.add(y, y), R.add(y, y))
with R.dataflow():
lv0 = R.add(y, y)
lv1 = R.add(y, y)
lv2 = R.add(lv0, lv1)
gv = lv2
R.output(gv)
return R.add(z, gv)
lv0 = bar1(x)
lv1 = bar2(x)
lv2 = R.add(lv0, lv1)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This actually will handle duplicate local functions. We could do a module-level analysis too for global functions
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wow, looks great! module-level analysis for global funcs will be helpful for my case, it could be a todo.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To be clear, it would only handle it if the bindings happen inside a DataflowBlock 🤪
We might want to consider expanding this pass to handle non-dataflow sections as well. Purity tracking would help with that
* [Unity][Pass] Add pass for CSE within dataflow * Fill in CSE definition and test cases * Missing trailing newline --------- Co-authored-by: Prakalp Srivastava <prakalp@octoml.ai>
* [Unity][Pass] Add pass for CSE within dataflow * Fill in CSE definition and test cases * Missing trailing newline --------- Co-authored-by: Prakalp Srivastava <prakalp@octoml.ai>
* [Unity][Pass] Add pass for CSE within dataflow * Fill in CSE definition and test cases * Missing trailing newline --------- Co-authored-by: Prakalp Srivastava <prakalp@octoml.ai>
This PR implements a dataflow block CSE transformation. Since we use ANF internally, the only nesting really occurs with tuples. The pass only needs to look at the RHS of bindings to determine if we have encountered any subexpressions.
Co-authored by @psrivas2 Prakalp Srivastava prakalp@octoml.ai