[Unity] Implement FNormalize attribute for operators #16067
Some Relax operators have requirements regarding their AST that are
stronger than are checked by the C++ types being used. These are
similar to checks that are present in the `tvm::relax::WellFormed`
utility, such as checks forbidding the use of undefined variables,
which are also stronger than required by the underlying C++ types.
However, because every operator may have unique requirements, it would
be unreasonable to expect a writer of a `relax::ExprMutator` to be
aware of and to maintain all such requirements.
This PR introduces an operator attribute `FNormalize`. If
defined, this function is used to apply an operator-specific
normalization.
* If no change is required, `FNormalize` should return the input
argument unmodified.
* `FNormalize` is only responsible for normalization of the operator
itself. The expression it returns may be unnormalized (e.g. contain
nested expressions).
* `FNormalize` receives the `BlockBuilder` as an argument, to allow
context-dependent normalization.
For example, an operator whose normalization requires in-line
expressions may use `BlockBuilder::LookupBinding` to perform
variable replacement.
* `FNormalize` is applied after `FInferStructInfo`. `FNormalize` may
assume that the `relax::Call` passed to `FNormalize` has
well-defined struct info.
* Corollary: `FInferStructInfo` may not assume that its
`relax::Call` argument has been passed through `FNormalize`.
This is a reasonable requirement, because (1) shape inference
should depend only on the struct info of arguments and not the
values themselves, and (2) this only impacts operators that use
`FNormalize`.
* `FNormalize` should not be used to apply simplifications, and should
be limited to cases where the same computation may be expressed in
multiple manners.
For example, replacing a by-variable tuple with an in-line tuple in
`R.call_tir` is a form of normalization, but replacing `R.add(arg,
R.const(0))` with `arg` is a form of simplification.
This separation is to ensure that `FNormalize` has minimal overhead,
as some simplifications may have large computational costs, and
`FNormalize` is applied as part of all `ExprMutator` usage. A later
PR will introduce an attribute `FSimplify`, along with a dedicated
pass to apply simplifications.
* Use of `FNormalize` is suppressed while parsing TVMScript.
TVMScript must be able to generate test cases that trigger specific
failure modes, and that may include producing un-normalized relax
IR. In addition, TVMScript must be stable when passed through a
round-trip from IR to text to IR.
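As a rough illustration of the first few points, here is a minimal sketch using toy stand-in types rather than the real `relax::Call` and `BlockBuilder` classes. `ToyExpr`, `ToyBlockBuilder`, `Make`, and `NormalizeCallTirLike` are hypothetical names invented for this example. Only the overall shape follows the description: return the input unmodified when nothing needs to change, and use a binding lookup to replace a by-variable tuple with an in-line tuple.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy stand-ins for Relax expressions (assumption: not the real TVM classes).
struct ToyExpr;
using ToyExprPtr = std::shared_ptr<const ToyExpr>;

enum class ToyKind { kVar, kTuple, kCall };

struct ToyExpr {
  ToyKind kind;
  std::string name;              // variable name or operator name
  std::vector<ToyExprPtr> args;  // tuple fields or call arguments
};

ToyExprPtr Make(ToyKind kind, std::string name, std::vector<ToyExprPtr> args = {}) {
  return std::make_shared<ToyExpr>(ToyExpr{kind, std::move(name), std::move(args)});
}

// Toy builder exposing only a binding lookup, loosely modelled on the
// role that BlockBuilder::LookupBinding plays in the description.
struct ToyBlockBuilder {
  std::map<std::string, ToyExprPtr> bindings;

  // Returns the value bound to a variable, or nullptr if there is none.
  ToyExprPtr LookupBinding(const ToyExprPtr& var) const {
    auto it = bindings.find(var->name);
    return it == bindings.end() ? nullptr : it->second;
  }
};

// Normalizer for a hypothetical call_tir-like operator whose argument at
// index 1 should be an in-line tuple rather than a variable bound to one.
ToyExprPtr NormalizeCallTirLike(const ToyBlockBuilder& bb, const ToyExprPtr& call) {
  assert(call->kind == ToyKind::kCall && call->args.size() >= 2);
  const ToyExprPtr& arg_tuple = call->args[1];

  if (arg_tuple->kind == ToyKind::kTuple) {
    return call;  // already normalized: return the input unmodified
  }
  if (arg_tuple->kind == ToyKind::kVar) {
    ToyExprPtr bound = bb.LookupBinding(arg_tuple);
    if (bound && bound->kind == ToyKind::kTuple) {
      std::vector<ToyExprPtr> new_args = call->args;
      new_args[1] = bound;  // replace the variable with the in-line tuple
      return Make(ToyKind::kCall, call->name, std::move(new_args));
    }
  }
  return call;  // no binding known here, so leave the call as it is
}
```

Returning the same pointer in the unchanged case lets a caller detect "no change" with a cheap identity comparison, which is the property the first bullet asks for.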
Thanks for the proposed change. I like how FNormalize can help reduce the overhead of creating certain operators and bring them back to normal form. I only have one comment on the well-formed check side. In this case, it is useful to have an intentionally duplicated check that is different from FNormalize, e.g. have a
Thank you, and I like the overall design. I think we still want to keep all the normalization logic in (Also, see the other comment for performance benchmarking.)
After thinking a bit more, I now agree that we can reuse FNormalize in the well-formed check. Thanks for proposing the change.
src/relax/ir/block_builder.cc
    // How much opt could an opt op Op if an opt op could op opt?
    if (auto opt_op = op->op.as<Op>()) {
      auto op = opt_op.value();
      if (apply_f_normalize_ && op_map_normalize_.count(op)) {
We can use this function https://github.com/apache/tvm/blob/main/include/tvm/ir/op.h#L476
Thank you, and updated! I had checked for a single-parameter `.get`, and an iterator-style `.find`, but hadn't found the two-parameter `.get`.
Updated to use `if (auto func_normalize = op_map_normalize_.get(call->op, nullptr); func_normalize != nullptr)`, here and in `well_formed.cc`.
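The lookup pattern discussed in this thread can be sketched with a toy attribute map. `ToyOpAttrMap`, `FNormalizeToy`, and `NormalizeIfRegistered` are hypothetical names for illustration, and operators and calls are reduced to strings. The real `OpAttrMap` in `include/tvm/ir/op.h` has a different interface beyond the two-parameter `get` mentioned above. The `apply_f_normalize` flag mirrors the `apply_f_normalize_` member visible in the diff.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

// Toy stand-in for an operator attribute map (assumption: not the TVM class).
template <typename V>
class ToyOpAttrMap {
 public:
  void Set(const std::string& op, V value) { data_[op] = std::move(value); }

  // Two-parameter get: returns def_value when the operator has no entry.
  V get(const std::string& op, V def_value) const {
    auto it = data_.find(op);
    if (it == data_.end()) {
      return def_value;
    }
    return it->second;
  }

 private:
  std::unordered_map<std::string, V> data_;
};

// In this sketch a "call" is just a string and a normalizer rewrites it.
using FNormalizeToy = std::function<std::string(const std::string&)>;

std::string NormalizeIfRegistered(const ToyOpAttrMap<FNormalizeToy>& op_map_normalize,
                                  const std::string& op, const std::string& call,
                                  bool apply_f_normalize = true) {
  // A single lookup with a null default replaces the count-then-index
  // pattern; the C++17 if-initializer keeps func_normalize scoped.
  if (auto func_normalize = op_map_normalize.get(op, nullptr);
      apply_f_normalize && func_normalize != nullptr) {
    return func_normalize(call);
  }
  return call;  // no normalizer registered, or normalization suppressed
}
```

Passing `apply_f_normalize = false` corresponds to the case in the PR description where normalization is suppressed, such as while parsing TVMScript.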
src/relax/analysis/well_formed.cc
    // case it produced a nested expression.

    if (auto opt_op = call->op.as<Op>()) {
      auto op = opt_op.value();
We can directly use this function to simplify the logic: https://github.com/apache/tvm/blob/main/include/tvm/ir/op.h#L476
op_map_normalize_.get(call->op, nullptr)
Thank you, and updated to use `if (auto func_normalize = op_map_normalize_.get(call->op, nullptr); func_normalize != nullptr)`.
Thank you, and changes made as suggested!
The PR description restates the points above and adds one more design decision:
* If an `IRModule` contains any non-normalized operators, the
`IRModule` is ill-formed. That is, all `FNormalize` operations on a
well-formed module are no-ops.
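The last design decision (a module is well-formed only if every `FNormalize` is a no-op on it) can be sketched as a check that re-runs each registered normalizer and compares the result with the input. `ToyCall`, `ToyNormalize`, and `AllCallsNormalized` are hypothetical names. Comparing by value is an assumption of this sketch; the actual check in `well_formed.cc` may compare differently.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <vector>

// Toy call node: an operator name plus a flag recording whether its tuple
// argument is written in-line (assumption: not the real relax::Call).
struct ToyCall {
  std::string op;
  bool tuple_is_inline;

  bool operator==(const ToyCall& other) const {
    return op == other.op && tuple_is_inline == other.tuple_is_inline;
  }
};

using ToyNormalize = std::function<ToyCall(const ToyCall&)>;

// Returns true only if every registered normalizer leaves every call
// unchanged, i.e. normalization is a no-op on the whole module.
bool AllCallsNormalized(const std::vector<ToyCall>& module,
                        const std::map<std::string, ToyNormalize>& normalizers) {
  for (const ToyCall& call : module) {
    auto it = normalizers.find(call.op);
    if (it == normalizers.end()) {
      continue;  // operator has no normalizer, so nothing to check
    }
    if (!(it->second(call) == call)) {
      return false;  // normalizer would change the call: not normalized
    }
  }
  return true;
}
```

Reusing the normalizer inside the check keeps a single source of truth for what "normalized" means, which is the position the reviewers converged on in the thread.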