-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[Relax][Transform] Provide callback versions of LazyTransformParams #16798
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relax][Transform] Provide callback versions of LazyTransformParams #16798
Conversation
This commit introduces two related utilities, `tir.analysis.is_pure_function` and `tir.analysis.assert_pure_function`. In contrast to the existing `tvm::tir::SideEffect`, which checks for side effects on a for a `PrimExpr`, `is_pure_function` checks for side effects for the function as a whole.
Prior to this commit, while expressions of type `DataType::Int(64)` could be computed in the `relax.transform.VMShapeLower`, expressions of any other type could not. This commit introduces `relax.transform.ComputePrimValue`, which produces `PrimFunc` subroutines to compute `PrimExpr` values of any dtype. This functionality will allow boolean values to be computed based on the symbolic values known at runtime.
Prior to this commit, the condition used for `relax::If` node and the
`"relax.assert_op"` operator was required to be a scalar tensor. This
made it difficult to alter behavior based on a runtime shape
parameter. For example, delegating to a vectorized implementation
based on a whether a tensor shape is divisible by the vector size.
This commit adds support for expressions of type `R.Prim('bool')` as
the conditional for `relax::If` and `"relax.assert_op"`, to allow
these use cases.
Prior to this commit, the `LazyTransformParams` function could be used to load model parameters on demand. However, the function used to load or set parameters needed to be registered within the global registry of `PackedFunc`s. This PR provides `LazyGetInput` and `LazySetOutput` transforms, which perform the lazy-loading through a `R.Callable` callback argument, rather than through a globally-registered `PackedFunc`.
|
This PR is currently marked as a draft, as the unit tests depend on functionality introduced in #16642. |
If `fget_param` accepts the parameter index first, and the parameter name second, then an implementation with signauture and default values of `def fget_param(index: int, name: Optional[str]=None)` could be used as either the callback of `LazyGetInput`, or as the globally-registered `"get_item"` for the existing `LazyTransformParams`, which should make it easier to transition between the two.
|
The prerequisite #16642 has now landed, so this PR is ready for review. |
|
Odd, the changes in this PR are still showing those that I presume were introduced in #16642. Maybe it's a Github thing. |
slyubomirsky
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for this change. Most of it was not hard to follow, though the indexing in one case raised an eyebrow for me. I think the pass descriptions should mention how the callback is supposed to work. We should probably also have small end-to-end tests that constructs examples of callbacks and uses them (the callbacks could be practically no-ops, just enough to verify that it works).
It might also be good to have an occasional line in the transformation passes explaining what is being constructed, though I was able to follow it.
| // Pass LazyTransformParams() { | ||
| // auto pass_func = [](Function func, IRModule, PassContext) -> Function { | ||
| // LazyInput mutator; | ||
| // return Downcast<Function>(mutator(func)); | ||
| // }; | ||
| // return CreateFunctionPass(/*pass_function=*/pass_func, | ||
| // /*opt_level=*/0, | ||
| // /*pass_name=*/"MutateOpsForTraining", | ||
| // /*required=*/{}); | ||
| // } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I doubt we want to keep this much commented-out code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, and removing. My intent is to replace the existing LazyTransformParams to be in terms of the C++ implementation, but didn't get around to it.
|
|
||
|
|
||
| def LazyGetInput() -> tvm.ir.transform.Pass: | ||
| """A pass that requests inputs lazily |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be preferable to have more detail in the doc comment as to what exactly the output of the pass will look like (the same is true for LazySetOutput too).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good call, and I've updated the docstrings for each.
slyubomirsky
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for addressing my comments. It would still be good to have some end-to-end tests that show what an example of such a callback would look like, if you're willing to do it.
|
Thank you, and good point on the end-to-end examples. I'm planning to upstream a use case, where the callback would load data from a |
Prior to this commit, the
LazyTransformParamsfunction could be used to load model parameters on demand. However, the function used to load or set parameters needed to be registered within the global registry ofPackedFuncs. This PR providesLazyGetInputandLazySetOutputtransforms, which perform the lazy-loading through aR.Callablecallback argument, rather than through a globally-registeredPackedFunc.