-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[Unity] Provide FuncStructInfo from bb.emit_te
#15026
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Lunderberg
wants to merge
5
commits into
apache:unity
from
Lunderberg:relax_funcstructinfo_from_call_te
Closed
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
ccf665c
[Unity] Provide FuncStructInfo from `bb.emit_te`
Lunderberg 0cc5f49
Lint fix
Lunderberg 58e7a3c
Import TupleStructInfo, maintain return type for output_sinfo
Lunderberg 6e0777f
Use FuncStructInfo for output-passing style
Lunderberg 74e207b
Annotate non-tensor arguments with PrimStructInfo
Lunderberg File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tqchen Is this consistent with how we want
FuncStructInfoto work? I thoughtPrimFuncs would beObjectStructInfo(this is what we wrote in the Relax spec). Perhaps they use aderive_funcinstead? If we want them to use ordinaryFuncStructInfo, does that also mean we'll allow them to be called outside ofcall_tir?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good question, and I had assumed this was intended, but would be interested to hear on it. I had mostly assumed that a Relax function and a TIR PrimFunc should expose the same information, so long as they have the same convention. That is, since the callsite has no distinction between a
GlobalVarrepresenting arelax::Functionor atir::PrimFunc, it seemed that thestruct_info_would depend only on the call sequence, and not the implementation dialect.Regarding
call_tir, I think the Relax-to-TIR calls are not restricted to theR.call_tirbuilt-in, because theLowerCallTIRpass can output arelax::CallNodewith aGlobalVaras its operation.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, permitting TIR calls outside of
call_tiris something we're trying to figure out with respect to phase ordering in Relax (see thread). I was under the impression that we did not want direct calls toPrimFuncs in the front end, so we should clarify that (we could put this on the agenda for a community meeting).FWIW, I don't think it would be hard to give
PrimFuncsFuncStructInfo, but there is the issue that they mutate their arguments, so they should be treated as impure (except when called viacall_tir).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point regarding the mutation. Thinking on it, I'm also not sure what the best
FuncStructInfowould be. It could reasonably be eitherFuncStructInfo(params = [*input_tensors, *output_tensors], ret=None), which matches the TIR function's signature, orFuncStructInfo(params=input_tensors, ret=relax.Tuple(output_tensors))`, which matches the exposed semantics in Relax.The original issue I was running into was that the result of
bb.emit_tedoesn't preserve the output struct information across mutations. If I have a TE function that accepts dynamic shapes, but which is called using static shapes, then the return type of therelax::Callshould be an inferred static shape. This works during the first usage of BlockBuilder, when a user is callingbb.emit_tedirectly. However, when the module is mutated, any mutation of the call node relies on therelax::Normalizerto regenerate the output struct info, and it doesn't have enough information to do so.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think if you just call the
PrimFuncby itself, it will work by mutating the arguments, so the best signature would be the first one you suggested.call_tir(the operator) is what's responsible for providing the nice wrapper over the mutation.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Lunderberg I've put this PR on the agenda for next week's community meeting. If you can make it, that would be good
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you! I probably won't be able to attend, given the timing, but I agree that discussion would be good. For now, I've converted this PR to a draft, to ensure that it can't be merged prior to discussion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Conclusion from the meeting: We think it's okay to permit direct calls to PrimFuncs as long as they're treated as impure and to give FuncStructInfo to PrimFuncs, though they should, again, be marked as impure.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder if unit (empty tuple) would make more sense as the return type, incidentally. Also, I do think the FuncStructInfo should have the purity set to false.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Summarizing our conversation from this morning:
Shape propagation through
bb.emit_teonly works during the initial construction of a Relax module, when therelax::Call("relax.call_tir",...)node is explicitly typed. Re-derivation of the output shape is not implemented, and so the shape information can be lost during lowering if the arguments tocall_tirchange.Annotating a PrimFunc with
FuncStructInfoto represent the output ofcall_tir(i.e. pure function, tensor output) wouldn't be accurate, and could cause confusion in the future.Annotating a PrimFunc with
FuncStructInfoto represent the PrimFunc itself (i.e. impure function, mutates arguments) would be accurate, but insufficient forcall_tirto propagate shapes, as input/output shapes are mixed.Would be useful to have a purity annotations for each parameter, dividing arguments into read-only, output, and mutate-in-place. This would allow a PrimFunc to be accurately annotated, and would be sufficient for
call_tirto identify outputs for shape propagation.