[Unity][Transform] Replace eligible operators with in-place versions in dataflow blocks #16129
c95d45f to
45eeb8c
Compare
…them. Also use pointers instead of non-const refs
}
// Replace buffers in a PrimFunc according to the mapping.
tir::Stmt RemapBuffers(const tir::Stmt& stmt, const Map<tir::Buffer, tir::Buffer>& buffer_map) {
I think this could be handled by calling tir::Specialize.
Ooh! I didn't know about it
It's a nice one for a lot of signature manipulations. There may need to be some post-processing, but it should work in this case. If the selected in-place input is parameter i, and the output is parameter j, then it would be called as auto new_func = Specialize(func, {{func->params[j], func->buffer_map[func->params[i]]}}).
Hm, using tir::Specialize still left some uses of the old output buffer around, which still seems like it would require the manual rewrite I had. It was odd. Namely, the old buffer was being used in the writes field of the block.
Did the output buffer have a DeclBuffer statement for it, or was it just used implicitly?
I think I ran into a similar issue, with #14565 to resolve it. That PR ended up not being merged, since RFC#70 would have required the DeclBuffer node. Unfortunately, #14778, which implements RFC#70, has been stuck for some time on ethos-u test cases.
There was no DeclBuffer, there was just a loose variable hanging around.
Got it. In that case, it would fall into the same bug that #14565 was intended to resolve. I think for now, it's probably better to use the separate mutator. If either RFC#70 gets fully implemented, or if the relax-to-TIR conversion provides a DeclBuffer annotation, then it could be updated.
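To make the issue in this thread concrete, here is a toy sketch in plain Python. It is not TVM code: `Block` and `remap_buffers` are hypothetical stand-ins, and buffers are represented by their names only. It shows a remapping that rewrites a block's `reads` and `writes` annotations together with the accesses in its body. Substituting only in the body would leave the old output buffer in `writes`, which is the leftover use described above.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class Block:
    """Toy stand-in for a TIR block (hypothetical, not the TVM class)."""
    reads: Tuple[str, ...]              # buffers the block declares it reads
    writes: Tuple[str, ...]             # buffers the block declares it writes
    body: Tuple[Tuple[str, str], ...]   # (access kind, buffer name) pairs


def remap_buffers(block: Block, buffer_map: Dict[str, str]) -> Block:
    """Replace buffers in the annotations AND the body according to the mapping."""
    def sub(name: str) -> str:
        return buffer_map.get(name, name)

    return Block(
        reads=tuple(sub(b) for b in block.reads),
        writes=tuple(sub(b) for b in block.writes),
        body=tuple((kind, sub(b)) for kind, b in block.body),
    )


# Output buffer "C" is replaced by input buffer "A" (the in-place case).
original = Block(
    reads=("A", "B"),
    writes=("C",),
    body=(("load", "A"), ("load", "B"), ("store", "C")),
)
remapped = remap_buffers(original, {"C": "A"})
```

After the call, `remapped.writes` is `("A",)` and no reference to `"C"` remains in either the annotations or the body.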
// replace the call with a call to an in-place PrimFunc.
// (Made public for testing.)
Call CreateInplaceCall(const Call& call, const Array<Integer>& inplace_indices) {
  static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
Should this use a separate legalization map? Even if they map to the same function, the per-operator support could then be checked by seeing if "FLegalizeInPlace" is defined for an operator.
That's worth thinking about. I didn't use it here because the operators that can be done in-place can also be implemented by applying the transformation given here to their legalized version. @tqchen, @MasterJH5574 (and others), do you think it would be worthwhile to add another map for in-place legalizations?
Good point. In that case, it probably makes sense to use FLegalize as-is. At some point in the future, the explicit listing of operators could be replaced by a TIR analysis pass. I think that (hypothetical future enhancement) would work especially well if this pass is then moved downstream of FuseTIR, so that it would be inspecting the already-generated TIR.
I would like for a pass like this to come later in compilation, but the issue is the fact it relies on working within single dataflow blocks, so it has to come before we get rid of dataflow blocks. That is also causing me a few issues with thinking about how to handle in-place split and concat.
This is probably a larger change, but would it make sense to either (A) implement it in terms of bindings without requiring a dataflow block, or (B) implement the pass as Sequential(ConvertToDataFlow(), current_pass(), ToNonDataflow())?
(B) would be easier to implement, but the use of ToNonDataflow as a post-processor would require that it be after any passes that require dataflows to be present.
If we don't want to have the dataflow requirement, we would also need to handle control flow. That's the main issue and there's no obvious way to deal with control flow without going with a solution more like abstract interpretation (see my comment about the general-purpose liveness analysis PR).
We could, I suppose, fake the requirement by analyzing the blocks before, after, and within control flow. It would be a large change but it could be done. The result would be less general than a solution that properly accounts for control flow but it would be a start.
Using the sequential transformation would be a less ad hoc version of the above.
Good call, and agreed that the sequential transform would be the better way to handle it, and a more consistent pattern that could be applied to other passes.
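As a rough illustration of option (B), the sketch below composes three toy passes in plain Python. Nothing here is TVM API: the module representation and the functions `convert_to_dataflow`, `use_inplace_calls`, `to_non_dataflow`, and `sequential` are hypothetical stand-ins that only mimic the shape of the proposed pipeline.

```python
from typing import Callable, Dict, List

Block = Dict[str, object]   # toy block: {"dataflow": bool, "ops": [operator names]}
Module = List[Block]
Pass = Callable[[Module], Module]


def convert_to_dataflow(mod: Module) -> Module:
    # Toy simplification: mark every block as a dataflow block.
    return [{**b, "dataflow": True} for b in mod]


def use_inplace_calls(mod: Module) -> Module:
    # Stand-in for the pass under review: it only rewrites dataflow blocks.
    return [
        {**b, "ops": [op + "_inplace" for op in b["ops"]]} if b["dataflow"] else b
        for b in mod
    ]


def to_non_dataflow(mod: Module) -> Module:
    # Post-processing step: drop the dataflow marking again.
    return [{**b, "dataflow": False} for b in mod]


def sequential(*passes: Pass) -> Pass:
    """Compose passes so they run left to right, like a pass pipeline."""
    def run(mod: Module) -> Module:
        for p in passes:
            mod = p(mod)
        return mod
    return run


pipeline = sequential(convert_to_dataflow, use_inplace_calls, to_non_dataflow)
result = pipeline([{"dataflow": False, "ops": ["add", "multiply"]}])
```

Run on its own, `use_inplace_calls` leaves a non-dataflow block untouched. Wrapped in the pipeline, the same block is rewritten and handed back without the dataflow marking, which is why the final step must come after every pass that needs dataflow blocks.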
vm = relax.VirtualMachine(ex, tvm.cpu())
res = vm["main"](x, y)
assert (expected == res.numpy()).all()
Do we want a pytest.xfail case for internal functions, to indicate that they are not yet supported? I'm thinking of cases like the following, where the current local analysis would conclude that Module.func cannot do in-place operations, but non-local analysis could perform Module.func in-place for some cases, depending on the liveness analysis in the calling scope.
@I.ir_module
class Module:
    @R.function
    def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
        cls = Module
        with R.dataflow():
            # Cannot be in-place, uses a function argument
            y = cls.subroutine(x)
            # Can be in-place, uses a local binding
            z = cls.subroutine(y)
            R.output(z)
        return z

    @R.function(private=True)
    def subroutine(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
        with R.dataflow():
            z = R.add(x, R.const(1, "int32"))
            R.output(z)
        return z
I'm always in favor of testing error conditions, so I'll add the test case too.
Sounds good! Though, I'd probably see this one as a not-yet-supported test case (pytest.mark.xfail), rather than a test of an error case (with pytest.raises(...)). Rather than verifying that a specific error is raised, it would document what behavior isn't covered by the current implementation.
On second thought, I'm not sure we should add an error-condition test case for something we don't even claim to support (the description of the pass does not reference trying to do function calls in place).
I could go either way on it. The high-level name DataflowUseInplaceCalls doesn't specify that the InplaceCalls only refer to native relax operators, and excludes potentially-in-place calls to other relax functions. The docstring does specify operators, so you're right that we don't claim support for it there.
I'd see a @pytest.mark.xfail(reason="Not currently supported") as a way to explicitly disclaim responsibility for that use case. It would indicate what that support would look like, and that it could be supported at some point, but that it isn't currently supported.
I can appreciate your point of view here, but I would point out that while tests are a form of documentation, documentation (such as doc comments) is also a form of documentation 🙂
If there is a strong need to reinforce the point that we are only looking for a small number of operators, we can do this, but I don't think there is right now, especially since handling Relax functions in-place is not something we have even planned on doing. (Possible in principle, but it's not on the horizon.)
Good point, and mostly just nitpicks on my side. While I think documenting potential future expansions is useful, I don't think it's worth delaying this PR. (Approved!)
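For readers unfamiliar with the distinction drawn in this thread, the sketch below contrasts the two pytest idioms. The helper `inplace_supported` and both test names are hypothetical and do not come from the PR; the helper only imitates a pass that handles built-in operators and nothing else.

```python
import pytest


def inplace_supported(callee_kind: str) -> bool:
    # Hypothetical stand-in for the pass: only built-in operators are handled.
    return callee_kind == "operator"


@pytest.mark.xfail(reason="Not currently supported")
def test_inplace_call_to_relax_function():
    # Documents behavior that is not covered yet. pytest reports the failing
    # assertion as an expected failure (xfail) rather than as an error.
    assert inplace_supported("relax_function")


def test_invalid_input_raises():
    # Error-condition test: verifies that a specific exception is raised.
    with pytest.raises(ValueError):
        int("not a number")
```

The first test records what support would look like without claiming it exists. The second pins down an error that the code is expected to raise today.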
…ns can be noted later)
ebc2d77 to
3107ce4
Compare
Thank you for drawing my attention to the dynamic case, @Lunderberg, I would not have guessed that it would "just work" like that 😄
Lunderberg
left a comment
Thank you for the changes, and I like the functionality!
Pursuant to issue #15319, this pass implements an in-place transformation for dataflow blocks. For each operator call in the dataflow block, the pass checks whether at least one argument is not live past the call and has no aliases that are live past the call. If this condition is met for at least one argument, the operator can be replaced with an in-place version called via call_tir_inplace. The in-place version is a TIR PrimFunc that is produced by taking the legalization of that operator and replacing its output with one of the inputs.
The liveness analysis is very simple since we are handling only single dataflow blocks. Handling general Relax control flow would require a solution more akin to #15689. The alias analysis, by contrast, is more complex since alias analysis is generally a tricky issue. It focuses on being able to handle common cases (especially tuples).
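The condition described above can be sketched in a few lines of plain Python. This is a simplified model under stated assumptions, not the pass's implementation: a dataflow block is a list of `(variable, operator, arguments)` tuples, function parameters and block outputs are treated as live forever, and the alias sets are supplied by the caller instead of being computed. The name `inplace_candidates` is hypothetical.

```python
from typing import Dict, List, Optional, Set, Tuple

# (bound variable, operator name, argument variables)
Binding = Tuple[str, str, List[str]]


def inplace_candidates(
    bindings: List[Binding],
    inputs: Set[str],
    outputs: Set[str],
    aliases: Optional[Dict[str, Set[str]]] = None,
) -> Dict[int, List[int]]:
    """For each binding index, list the argument positions eligible for in-place use."""
    aliases = aliases or {}
    forever = float("inf")

    # Liveness within a single block: a variable is live until its last use.
    last_use: Dict[str, float] = {}
    for i, (_, _, args) in enumerate(bindings):
        for arg in args:
            last_use[arg] = i
    # Parameters and block outputs may be used outside the block.
    for var in inputs | outputs:
        last_use[var] = forever

    result: Dict[int, List[int]] = {}
    for i, (_, _, args) in enumerate(bindings):
        result[i] = [
            k
            for k, arg in enumerate(args)
            # Eligible only if the argument and all of its aliases die here.
            if all(last_use.get(v, -1) <= i for v in aliases.get(arg, set()) | {arg})
        ]
    return result


block = [
    ("z", "add", ["x", "y"]),
    ("w", "multiply", ["z", "y"]),
    ("q", "subtract", ["w", "y"]),
]
print(inplace_candidates(block, inputs={"x", "y"}, outputs={"q"}))
# → {0: [], 1: [0], 2: [0]}
```

In this example the first call cannot be done in place because both arguments are function parameters, while `z` and `w` are local values whose last use is the call that consumes them.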