[Relax] Implement relax.transform.RemoveSymbolicExpressionsInSubroutine #17080
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This is a follow-up commit to
#16637, which updated
relax.transform.FuseOpsto provide additional parameters defining symbolic variables required by the fused functions. While this ensures thatrelax.transform.FuseOpsproduces well-formed Relax functions, these additional arguments can break some kernel implementations.This commit implements a new transform
RemoveSymbolicExpressionsInSubroutineto resolve this issue. This transform identifies function arguments whose sole purpose is to compute a symbolic expression, when that symbolic expression could be inferred from tensor shapes.For example, consider the following Relax function:
The
datatensor may be used to inferhidden_size, but cannot be used to inferbatch_sizeorseq_len. TheR.Shapeparameter exists solely to definebatch_sizeandseq_len, since all symbolic variables must be defined. However, neitherbatch_sizenorseq_lenare ever used outside of the expressionbatch_size * seq_len, and the value ofbatch_size * seq_lencould be inferred from the shape of thedatatensor.This new transform identifies cases where an argument is otherwise unnecessary, and replaces the symbolic expression with a new argument. This makes the
dummy_arg: R.Shapebe entirely unused, so a later use ofrelax.transform.RemoveUnusedParameters()can remove the parameter altogether.