[Relax] Allow R.Prim('bool') in relax::If and assert_op #16642
Conversation
```python
def ComputePrimValue() -> tvm.ir.transform.Pass:
    """Compute all R.prim_value instances
```
I think this description should be more precise. I assume it's supposed to come late in the phase ordering since it inserts direct calls to PrimFuncs? (And so should probably come after we end purity checking?)
Good point on improving the docstring.
Regarding phase ordering, I don’t think we need to restrict its usage. The calls to PrimFunc instances are valid in user-provided Relax functions, so this could occur early in the phase ordering. The only limitation is that it must occur before VMShapeLower, as VMShapeLower expects all R.prim_value(arg) expressions to have int64 arguments.
Yes, but inserting the PrimFunc calls will likely change the purity of the functions where that happens. `call_tir` could be used to avoid that, but that would require running this pass before `call_tir` is lowered.
Hmm. I think this is another argument in favor of allowing FuncStructInfo annotations for PrimFunc objects, as that would allow the generated PrimFunc instances to be marked as pure functions. I'll add a unit test to see how well that works for maintaining purity tracking when calling a pure PrimFunc from a pure Relax function.
Are they truly pure? No modification of external values?
Edit: Yeah, they just use a return value. I imagine this means that we actually have to check the bodies of PrimFuncs to determine if they're pure and also give users the option to override the automatic judgment. The rules for that can be very simple: Consider it impure if there is any write to a tensor or external call.
I was thinking an even simpler heuristic: All PrimFuncs are impure, unless explicitly annotated otherwise. In this case, since the functions are being generated in a manner that requires purity, it could also provide the annotation.
For the long term, agreed: it would be good to have the TIR-level purity analysis. I think I'd weaken the condition you mentioned slightly: a function is impure if it writes to a buffer that it didn't itself allocate.
Good point. I'm fine with requiring an annotation since the great majority of PrimFuncs are going to be impure.
```cpp
auto ret_dtype = node->value->dtype;
auto param_vars = tir::UndefinedVars(node->value);
```
Would this call know which TIR vars are in scope per the Relax scoping rules?
This call isn’t aware of the Relax scoping rules, but I don’t think there’s a benefit to checking them at this point. Any well-formed input that only uses in-scope TIR variables would produce well-formed output. Any ill-formed input that uses out-of-scope TIR variables would produce ill-formed output that still uses the out-of-scope TIR variables.
Validating the Relax scoping rules at this point would require additionally tracking the in-scope variables, which would duplicate the functionality of the well-formed checker. Since this pass wouldn’t make any ill-formed usage worse (and therefore harder to debug), I don’t think it’s worth duplicating the in-scope tracking here.
Yeah I think you're right that this would still work out just fine in that case.
Probably overkill, but I think it ended up being simpler to just generate the appropriate FuncStructInfo for a PrimFunc on construction, and to populate the currently-empty struct_info_ field. This includes inspecting the body to see if the PrimFunc is pure.
```python
@T.prim_func(private=True)
def compute_symbolic_expr(N: T.int64, M: T.int64) -> T.int64:
    T.ret(N * M)
```
As written, would this roundtrip? It seems you're manually setting the purity in the definition, but how would the parser know this?
Good point. The test case only passes right now because the PrimFuncNode::struct_info_, if it exists, isn't checked for structural equality.
Now that purity checking is implemented for PrimFuncs, this will roundtrip, correct?
Excellent, I'm glad to see an analysis for PrimFunc purity. Do you think it's necessary to have an override to assert a PrimFunc is pure? It's doubtful that we would need it right now, but it might be a good thing to have in the back of our pocket (i.e., an escape hatch). I think it could be an analogue to the existing Relax-side purity override.
I will update both the TIR and Relax specs to take note of purity checking for PrimFuncs ^_^
I think I prefer having the overrides at an individual operator level, rather than on the function as a whole. So, a user could use such an override at a specific call site.
That's a good idea. What would be the analogue to such annotations in a PrimFunc? Do we think we need it?
(Apologies, missed the notification on this one.) For a PrimFunc, …
This PR needs an additional test case before merging. On the main branch, the following IRModule is valid. However, with this PR, it fails during parsing.

```python
@I.ir_module
class Module:
    @R.function
    def main(A: R.Tensor) -> R.Prim("bool"):
        return Module.is_bfloat16_dtype(A)

    @T.prim_func(private=True)
    def is_bfloat16_dtype(tensor: T.handle) -> T.bool:
        T.func_attr({"tir.is_scheduled": True, "tir.is_host_func": True})
        # From #include <tvm/tir/builtin.h>
        kArrTypeCode = T.meta_var(5)
        kArrTypeBits = T.meta_var(6)
        kArrTypeLanes = T.meta_var(7)
        # From #include <dlpack/dlpack.h>
        kDLBfloat = T.meta_var(4)
        type_code = T.tvm_struct_get(tensor, 0, kArrTypeCode, dtype="uint8")
        type_bits = T.tvm_struct_get(tensor, 0, kArrTypeBits, dtype="uint8")
        type_lanes = T.tvm_struct_get(tensor, 0, kArrTypeLanes, dtype="uint16")
        is_bfloat16: T.bool = (
            (type_code == kDLBfloat) and (type_bits == 16) and (type_lanes == 1)
        )
        return is_bfloat16
```

The failure occurs due to the inferred struct info of `is_bfloat16_dtype`. I think the fix will be to update the struct info inference to handle this case.
The additional test case has been added, and the failure it exposed has been fixed, so this PR is ready for review/merge.
```cpp
return relax::TensorStructInfo(shape, buf->dtype);
}

if (auto prim_type = param->type_annotation.as<PrimTypeNode>();
```
By my understanding of the TIR spec, some handle-typed vars sound like they would correspond more to Objects in Relax, i.e., void-typed values that are completely opaque. Are we sure this shouldn't be the case?
That's a good call, and I've updated it to provide R.Object instead of R.Tensor in this case. My main goal with this conditional was to avoid having DLTensor* parameters be labeled with R.Prim(dtype='handle'), as that would be incompatible with R.Tensor arguments.
In the future, the annotations could be improved by inspecting the PrimFunc body. If the handle-typed tir.Var is used in a context that requires a DLTensor* argument (typically builtin::tvm_struct_get), then the annotation could be improved from R.Object to R.Tensor. That would allow better type-checking at the Relax callsite, but isn't necessary for a first implementation.
Thank you for your changes; I am very excited to have purity checking and more accurate StructInfo for PrimFuncs. I had one question remaining about the inference for argument types; see my comment on those lines.
Rebased onto main to resolve a merge conflict.
slyubomirsky left a comment:
Thank you for responding to feedback. I am pleased to see the improved type-checking for PrimFuncs in Relax.
@Lunderberg After this pass introduced support for computing struct_info for PrimFuncs, I'm seeing a case where a PrimFunc whose output buffer is modified triggers a failure. A small test case that reproduces the issue is something like below:

```python
import tvm
from tvm.script import relax as R, tir as T, ir as I
from tvm import tir, relax

@I.ir_module
class Before:
    @R.function
    def main(
        x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")
    ) -> R.Tensor((4, 4), dtype="float32"):
        with R.dataflow():
            lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
                x, index_map=lambda i: (i // 4, i % 4), pad_value=None
            )
            gv: R.Tensor((4, 4), dtype="float32") = lv
            R.output(gv)
        return gv

if __name__ == '__main__':
    mod = Before
    mod = relax.transform.LegalizeOps()(mod)
    mod = relax.transform.CallTIRRewrite()(mod)
```

From what I've been able to understand, the …
Ooh, interesting. I wasn't aware that …
* [TIR][Analysis] Implemented tir.analysis.is_pure_function

This commit introduces two related utilities,
`tir.analysis.is_pure_function` and `tir.analysis.assert_pure_function`.
In contrast to the existing `tvm::tir::SideEffect`, which checks for
side effects for a `PrimExpr`, `is_pure_function` checks for side
effects for the function as a whole (a usage sketch follows after this
commit list).
* [Transform] Implement relax.transform.ComputePrimValue
Prior to this commit, while expressions of type `DataType::Int(64)`
could be computed by `relax.transform.VMShapeLower`, expressions
of any other type could not. This commit introduces
`relax.transform.ComputePrimValue`, which produces `PrimFunc`
subroutines to compute `PrimExpr` values of any dtype.
This functionality will allow boolean values to be computed based on
the symbolic values known at runtime.
* [Relax] Allow R.Prim('bool') in relax::If and assert_op
Prior to this commit, the condition used for a `relax::If` node and the
`"relax.assert_op"` operator was required to be a scalar tensor. This
made it difficult to alter behavior based on a runtime shape
parameter. For example, delegating to a vectorized implementation
based on whether a tensor shape is divisible by the vector size.

This commit adds support for expressions of type `R.Prim('bool')` as
the conditional for `relax::If` and `"relax.assert_op"`, to allow
these use cases.
* Lint fix
Thanks a lot for the very quick fix. Yes, it does seem to solve the issue.