@Lunderberg
Contributor

Prior to this commit, the `LazyTransformParams` function could be used to load model parameters on demand. However, the function used to load or set parameters needed to be registered within the global registry of `PackedFunc`s. This PR provides `LazyGetInput` and `LazySetOutput` transforms, which perform the lazy-loading through a `R.Callable` callback argument, rather than through a globally-registered `PackedFunc`.
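
As a rough illustration of the callback style described above, here is a plain-Python sketch that does not use TVM. The function `transform_params_lazily`, the parameter names, the doubling "transform", and the `(index, value)` signature of `fset_output` are all assumptions made for the example, not the code the passes generate.

```python
from typing import Callable, Dict, List

# Hypothetical parameter names and backing storage for the example.
PARAM_NAMES: List[str] = ["a", "b", "c"]
STORAGE: Dict[str, float] = {"a": 1.0, "b": 2.5, "c": -4.0}


def transform_params_lazily(
    fget_param: Callable[[int, str], float],
    fset_output: Callable[[int, float], None],
) -> None:
    """Stand-in for a lazily transformed function (assumed shape).

    Inputs are requested one at a time through `fget_param`, and each
    result is handed to `fset_output` as soon as it is computed, so the
    full set of inputs or outputs never has to be held at once.
    """
    for index, name in enumerate(PARAM_NAMES):
        value = fget_param(index, name)   # request the input only when needed
        fset_output(index, value * 2.0)   # hand off the result immediately


load_order: List[int] = []
results: Dict[int, float] = {}


def fget_param(index: int, name: str) -> float:
    # Callback passed as an argument; no global registration is required.
    load_order.append(index)
    return STORAGE[name]


def fset_output(index: int, value: float) -> None:
    results[index] = value


transform_params_lazily(fget_param, fset_output)
```

Because the callbacks are ordinary arguments, two callers can supply different loaders in the same process without competing for a single globally-registered name.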

This commit introduces two related utilities,
`tir.analysis.is_pure_function` and `tir.analysis.assert_pure_function`.
In contrast to the existing `tvm::tir::SideEffect`, which checks for
side effects of a `PrimExpr`, `is_pure_function` checks for side
effects of the function as a whole.

Prior to this commit, while expressions of type `DataType::Int(64)`
could be computed in `relax.transform.VMShapeLower`, expressions
of any other type could not.  This commit introduces
`relax.transform.ComputePrimValue`, which produces `PrimFunc`
subroutines to compute `PrimExpr` values of any dtype.

This functionality will allow boolean values to be computed based on
the symbolic values known at runtime.

Prior to this commit, the condition used for a `relax::If` node and the
`"relax.assert_op"` operator was required to be a scalar tensor.  This
made it difficult to alter behavior based on a runtime shape
parameter, for example when delegating to a vectorized implementation
based on whether a tensor shape is divisible by the vector size.

This commit adds support for expressions of type `R.Prim('bool')` as
the conditional for `relax::If` and `"relax.assert_op"`, to allow
these use cases.

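
The use case mentioned above, choosing an implementation from a boolean computed on a runtime shape, can be sketched in plain Python as an analogy. This is not Relax or TVMScript; `VECTOR_SIZE` and the `add_one*` helpers are hypothetical names chosen for the example.

```python
from typing import List

VECTOR_SIZE = 4  # hypothetical vector width


def add_one_scalar(data: List[float]) -> List[float]:
    # Fallback path that works for any length.
    return [x + 1.0 for x in data]


def add_one_vectorized(data: List[float]) -> List[float]:
    # Processes the data in fixed-size chunks, so the length must be a
    # multiple of VECTOR_SIZE.
    assert len(data) % VECTOR_SIZE == 0
    out: List[float] = []
    for start in range(0, len(data), VECTOR_SIZE):
        chunk = data[start:start + VECTOR_SIZE]
        out.extend(x + 1.0 for x in chunk)
    return out


def add_one(data: List[float]) -> List[float]:
    n = len(data)  # plays the role of a symbolic shape known at runtime
    is_divisible = n % VECTOR_SIZE == 0  # a boolean value, not a scalar tensor
    if is_divisible:
        return add_one_vectorized(data)
    return add_one_scalar(data)
```

The condition here is an ordinary boolean derived from the shape, which is the role `R.Prim('bool')` plays as the conditional in the commit message above.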
@Lunderberg
Contributor Author

This PR is currently marked as a draft, as the unit tests depend on functionality introduced in #16642.

If `fget_param` accepts the parameter index first, and the parameter
name second, then an implementation with signature and default values
of `def fget_param(index: int, name: Optional[str]=None)` could be
used as either the callback of `LazyGetInput`, or as the
globally-registered `"get_item"` for the existing
`LazyTransformParams`, which should make it easier to transition
between the two.
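
A minimal sketch of such a dual-use function, in plain Python without TVM, might look like the following. The parameter table and the assumption that the registry-style caller passes only the index are illustrative; only the signature `fget_param(index, name=None)` comes from the comment above.

```python
from typing import Dict, List, Optional

# Hypothetical parameter table for the example.
PARAM_NAMES: List[str] = ["weight", "bias"]
PARAM_STORE: Dict[str, List[float]] = {
    "weight": [0.5, -0.5],
    "bias": [0.1],
}


def fget_param(index: int, name: Optional[str] = None) -> List[float]:
    """Return a parameter by index, using the name when it is supplied."""
    if name is None:
        # Index-only call style (assumed for the globally-registered
        # "get_item"): recover the name from the index.
        name = PARAM_NAMES[index]
    return PARAM_STORE[name]


# Called with the index only.
by_index = fget_param(1)

# Called with the index first and the name second.
by_index_and_name = fget_param(1, "bias")
```

Both call styles resolve to the same stored value, so one function body can serve either interface during a transition.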
@Lunderberg Lunderberg marked this pull request as ready for review March 28, 2024 21:54
@Lunderberg
Contributor Author

The prerequisite #16642 has now landed, so this PR is ready for review.

@slyubomirsky
Contributor

Odd, the changes in this PR are still showing those that I presume were introduced in #16642. Maybe it's a GitHub thing.

Contributor

@slyubomirsky slyubomirsky left a comment

Thank you for this change. Most of it was not hard to follow, though the indexing in one case raised an eyebrow for me. I think the pass descriptions should mention how the callback is supposed to work. We should probably also have small end-to-end tests that construct examples of callbacks and use them (the callbacks could be practically no-ops, just enough to verify that it works).

It might also be good to have an occasional line in the transformation passes explaining what is being constructed, though I was able to follow it.

Comment on lines 264 to 273
// Pass LazyTransformParams() {
//   auto pass_func = [](Function func, IRModule, PassContext) -> Function {
//     LazyInput mutator;
//     return Downcast<Function>(mutator(func));
//   };
//   return CreateFunctionPass(/*pass_function=*/pass_func,
//                             /*opt_level=*/0,
//                             /*pass_name=*/"MutateOpsForTraining",
//                             /*required=*/{});
// }
Contributor

I doubt we want to keep this much commented-out code.

Contributor Author

Thank you, I'm removing it. My intent was to reimplement the existing LazyTransformParams in terms of the C++ implementation, but I didn't get around to it.



def LazyGetInput() -> tvm.ir.transform.Pass:
    """A pass that requests inputs lazily
Contributor

I think it would be preferable to have more detail in the doc comment as to what exactly the output of the pass will look like (the same is true for LazySetOutput too).

Contributor Author

Good call, and I've updated the docstrings for each.

Contributor

@slyubomirsky slyubomirsky left a comment

Thank you for addressing my comments. It would still be good to have some end-to-end tests that show what an example of such a callback would look like, if you're willing to do it.

@Lunderberg
Contributor Author

Thank you, and good point on the end-to-end examples. I'm planning to upstream a use case, where the callback would load data from a .safetensors file, and that would be a good place to add a practical example overall.

@Lunderberg Lunderberg merged commit 61249b4 into apache:main Apr 3, 2024
@Lunderberg Lunderberg deleted the relax_callback_lazy_transform_params branch April 3, 2024 16:26