@slyubomirsky
Contributor

This PR implements a dataflow block CSE transformation. Since we use ANF internally, the only nesting really occurs with tuples. The pass only needs to look at the RHS of bindings to determine if we have encountered any subexpressions.

Co-authored-by: Prakalp Srivastava (@psrivas2) <prakalp@octoml.ai>
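
For reference, here is a minimal sketch of how the pass could be applied from Python. It assumes the pass is exposed as relax.transform.EliminateCommonSubexpr(); the example module is illustrative rather than taken from the test suite.

import tvm
from tvm import relax
from tvm.script import relax as R

@tvm.script.ir_module
class Before:
    @R.function
    def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32")):
        with R.dataflow():
            lv0 = R.add(x, y)
            lv1 = R.add(x, y)  # same RHS as lv0, so CSE can rebind lv1 to lv0
            gv = R.multiply(lv0, lv1)
            R.output(gv)
        return gv

# After the pass, R.add(x, y) is computed once and lv1 becomes a rebinding of lv0.
After = relax.transform.EliminateCommonSubexpr()(Before)
After.show()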

@tvm-bot
Collaborator

tvm-bot commented Mar 21, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

// 3. Scalar constants (not much benefit from binding to a var)
if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() ||
      e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() ||
      (e.as<ConstantNode>() && (e.as<ConstantNode>()->is_scalar())))) {
Contributor Author

It's up to us if we think it's worth rebinding scalar constants. I doubt it, though.

Comment on lines +102 to +133
Expr VisitExpr_(const FunctionNode* func) override {
  // for an inner function, we will do CSE on its body
  Expr new_body = ExprMutator::VisitExpr(func->body);
  if (new_body.same_as(func->body)) {
    return GetRef<Expr>(func);
  }
  return Function(func->params, new_body, func->ret_struct_info, func->attrs, func->span);
}

// this should happen only for the inner function case
Expr VisitExpr_(const SeqExprNode* seq) override {
  bool all_unchanged = true;
  Array<BindingBlock> new_blocks;
  // apply CSE within dataflow blocks only
  for (auto block : seq->blocks) {
    if (const DataflowBlockNode* df_block = block.as<DataflowBlockNode>()) {
      auto new_df_block = EliminateCommonSubexpr(GetRef<DataflowBlock>(df_block));
      if (!new_df_block.same_as(block)) {
        new_blocks.push_back(new_df_block);
        all_unchanged = false;
        continue;
      }
    }
    new_blocks.push_back(block);
  }

  if (all_unchanged) {
    return GetRef<Expr>(seq);
  }
  // do not visit the body
  return SeqExpr(new_blocks, seq->body, seq->span);
}
Contributor Author

This seemed like a bit of a strange thing to have to implement in a dataflow block pass. It could be avoided if we require lambda-lifting (arguably we should for all DF block passes).

Contributor Author

@slyubomirsky Mar 22, 2023

Update: Based on the Unity Community Meeting discussion, it doesn't sound like there is much appetite for imposing phase orderings like this, so I would instead be interested in whether there is a clean way to deal with local functions in dataflow block passes (generalizing the approach shown here, for example). I am sure other dataflow block passes don't handle the local-function case and might exhibit strange bugs if given a program with local functions.

edit: Possible solution: change DataflowBlockMutator to look for inner functions and process them separately. That way, the pass_func for dataflow block passes can safely ignore inner functions, since they would be handled by the pass infrastructure.

Comment on lines +47 to +51
lv0 = R.add(x, y)
# can combine with canonicalizing bindings
# and getting rid of unused bindings to eliminate this line too
lv1 = lv0
gv = R.multiply(lv0, lv1)
Contributor Author

I didn't want to reimplement the functionality of the canonicalize bindings pass, but we could fold that in here if we want to (I don't think it's a good idea).
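
For what it's worth, that cleanup could be expressed as a small pipeline around this pass rather than folded into it. A sketch only: it assumes CanonicalizeBindings is exposed as relax.transform.CanonicalizeBindings() and that an unused-binding elimination pass (written here as relax.transform.DeadCodeElimination(), an assumption) is available.

import tvm
from tvm import relax

# Pass names other than EliminateCommonSubexpr are assumptions about the Python API.
cleanup = tvm.transform.Sequential(
    [
        relax.transform.EliminateCommonSubexpr(),  # introduces trivial rebindings like lv1 = lv0
        relax.transform.CanonicalizeBindings(),    # rewrites uses of lv1 to lv0
        relax.transform.DeadCodeElimination(),     # drops the now-unused lv1 binding
    ]
)
After = cleanup(Before)  # Before: any IRModule, e.g. the one from the earlier sketch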

Comment on lines +169 to +179
# can further clean this up
# using canonicalize bindings, eliminate unused bindings, and CSE again
lv0 = bar(x)
lv1 = lv0
lv2 = R.add(lv0, lv1)
lv3 = lv0
lv4 = lv0
lv5 = R.add(lv3, lv4)
lv6 = R.add(lv2, lv5)
gv = lv6
R.output(gv)
Contributor Author

This is not ideal because there are still repeated additions. @psrivas2 once proposed having a canonicalization pass that runs until fixpoint. This is an example where running such a pass will help.
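
Such a fixpoint driver could be layered on top of the existing passes without changing them, for example (a sketch only; cleanup is the assumed Sequential pipeline from the earlier comment):

import tvm

def run_to_fixpoint(mod, pipeline, max_iters=10):
    """Re-apply the pipeline until the module stops changing structurally."""
    for _ in range(max_iters):
        new_mod = pipeline(mod)
        if tvm.ir.structural_equal(new_mod, mod):
            return new_mod
        mod = new_mod
    return mod

After = run_to_fixpoint(Before, cleanup)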

@psrivas2
Contributor

cc @masahi

@spectrometerHBH
Contributor

Nice work!

In what scenarios do we want to apply this pass?

@slyubomirsky requested a review from masahi March 22, 2023 18:05
@slyubomirsky
Contributor Author

The ideal scenario would be if some model (especially an imported one) repeatedly invokes the same operation (an expensive one). We encountered that pattern already in some imported models. CSE is a classic optimization.

@spectrometerHBH
Contributor

spectrometerHBH commented Mar 22, 2023

The ideal scenario would be if some model (especially an imported one) repeatedly invokes the same operation (an expensive one). We encountered that pattern already in some imported models. CSE is a classic optimization.

Yeah, I agree. One issue that comes to mind is that it may be critical to decide which subexpressions to eliminate. If a subexpression is lightweight and can be inlined into the surrounding ops, then we should probably decline to eliminate the redundancy.

In classical settings, the trade-off of CSE is that it enlarges the live ranges of some registers, which can cause performance regressions due to register spills. But for DL workloads, the model may simply become undeployable within a given amount of GPU memory, which is more severe than running slower.

Of course, the memory planner can try smarter strategies similar to live-range splitting, such as recomputing some values.

@slyubomirsky
Contributor Author

slyubomirsky commented Mar 22, 2023

In this case, the bindings will be live only within a single dataflow block, so I don't think there will be many issues with keeping values live for much longer than they would be otherwise.

It would be easy to add a heuristic for deciding when we shouldn't deduplicate.

edit: Playing around with recomputation to reduce memory requirements is definitely a strategy we can try.
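
One way to express such a heuristic, sketched in Python for illustration (this helper is hypothetical and not part of this PR; the op-name set is made up):

import tvm
from tvm import relax

# Ops considered cheap enough that deduplication buys little; illustrative only.
CHEAP_OPS = {"relax.add", "relax.multiply", "relax.reshape"}

def worth_deduplicating(expr: relax.Expr) -> bool:
    """Skip lightweight calls that downstream codegen would likely fuse or inline anyway."""
    if isinstance(expr, relax.Call) and isinstance(expr.op, tvm.ir.Op):
        return expr.op.name not in CHEAP_OPS
    return True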

Contributor

@psrivas2 left a comment

LGTM! Thanks for handling the inner func case here too!

Going forward, we should definitely run the lambda-lifting pass ahead of time to spare dataflow pass writers this burden, or modify DataflowBlockMutator to handle inner functions.

Member

@yongwww left a comment

Looks good to me!

with R.dataflow():
    # we are going to do CSE inside the local function
    @R.function
    def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
Member

Not sure if we should handle the case of duplicate global/local functions. I ran into this case when lowering JAX to HLO.

            @R.function
            def bar1(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
                z = R.add(R.add(y, y), R.add(y, y))
                with R.dataflow():
                    lv0 = R.add(y, y)
                    lv1 = R.add(y, y)
                    lv2 = R.add(lv0, lv1)
                    gv = lv2
                    R.output(gv)
                return R.add(z, gv)
            
            @R.function
            def bar2(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
                z = R.add(R.add(y, y), R.add(y, y))
                with R.dataflow():
                    lv0 = R.add(y, y)
                    lv1 = R.add(y, y)
                    lv2 = R.add(lv0, lv1)
                    gv = lv2
                    R.output(gv)
                return R.add(z, gv)

            lv0 = bar1(x)
            lv1 = bar2(x)
            lv2 = R.add(lv0, lv1)

Contributor Author

@slyubomirsky Mar 23, 2023

This actually will handle duplicate local functions. We could do a module-level analysis for global functions as well.
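
A module-level analysis could start from structural hashing of the functions in the IRModule, e.g. along these lines (only a sketch; redirecting call sites to a single surviving function is omitted):

from collections import defaultdict
import tvm

def find_duplicate_global_functions(mod: tvm.IRModule):
    """Return pairs of GlobalVars whose functions are structurally equal."""
    groups = defaultdict(list)
    for gvar, func in mod.functions.items():
        groups[tvm.ir.structural_hash(func)].append(gvar)
    duplicates = []
    for gvars in groups.values():
        # Hash collisions are possible, so confirm with structural_equal.
        for i, a in enumerate(gvars):
            for b in gvars[i + 1:]:
                if tvm.ir.structural_equal(mod[a], mod[b]):
                    duplicates.append((a, b))
    return duplicates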

Member

Wow, looks great! A module-level analysis for global funcs would be helpful for my case; it could be a TODO.

Contributor Author

To be clear, it would only handle it if the bindings happen inside a DataflowBlock 🤪

We might want to consider expanding this pass to handle non-dataflow sections as well. Purity tracking would help with that.

@slyubomirsky merged commit f4d5964 into apache:unity Mar 27, 2023
tqchen pushed a commit that referenced this pull request Apr 1, 2023
* [Unity][Pass] Add pass for CSE within dataflow

* Fill in CSE definition and test cases

* Missing trailing newline

---------

Co-authored-by: Prakalp Srivastava <prakalp@octoml.ai>