Conversation

@Lunderberg
Contributor

Prior to this commit, the condition used for the relax::If node and the "relax.assert_op" operator was required to be a scalar tensor. This made it difficult to alter behavior based on a runtime shape parameter: for example, delegating to a vectorized implementation based on whether a tensor shape is divisible by the vector size.

This commit adds support for expressions of type R.Prim('bool') as the conditional for relax::If and "relax.assert_op", to allow these use cases.
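The dispatch pattern this enables can be sketched in plain Python (the kernel names and the vector width below are illustrative assumptions, not part of the PR):

```python
# Plain-Python sketch of the motivating use case: branch on a runtime
# shape predicate (a plain boolean, analogous to R.Prim('bool')) rather
# than a scalar tensor condition.
VECTOR_WIDTH = 4  # assumed vector width, for illustration only

def vectorized_sum(xs):
    # Stand-in for a vectorized kernel; valid only when
    # len(xs) is divisible by VECTOR_WIDTH.
    total = 0.0
    for i in range(0, len(xs), VECTOR_WIDTH):
        total += sum(xs[i : i + VECTOR_WIDTH])
    return total

def scalar_sum(xs):
    # Stand-in for the generic fallback kernel.
    return float(sum(xs))

def dispatch_sum(xs):
    # Analogous to relax::If with an R.Prim('bool') condition: the
    # predicate is computed from the runtime shape, not from a tensor.
    is_divisible = len(xs) % VECTOR_WIDTH == 0
    return vectorized_sum(xs) if is_divisible else scalar_sum(xs)
```
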



def ComputePrimValue() -> tvm.ir.transform.Pass:
    """Compute all R.prim_value instances
Contributor

I think this description should be more precise. I assume it's supposed to come late in the phase ordering since it inserts direct calls to PrimFuncs? (And so should probably come after we end purity checking?)

Contributor Author

Good point on improving the docstring.

Regarding phase ordering, I don’t think we need to restrict its usage. The calls to PrimFunc instances are valid in user-provided Relax functions, so this could occur early in the phase ordering. The only limitation is that it must occur before VMShapeLower, as VMShapeLower expects all R.prim_value(arg) expressions to have int64 arguments.

Contributor

Yes, but inserting the PrimFunc calls will likely change the purity of the functions where that happens. call_tir could be used to avoid that, but then this pass would need to run before call_tir is lowered.

Contributor Author

Hmm. I think this is another argument in favor of allowing FuncStructInfo annotations for PrimFunc objects, as that would allow the generated PrimFunc instances to be marked as pure functions. I'll add a unit test to see how well that works for maintaining purity tracking when calling a pure PrimFunc from a pure Relax function.

@slyubomirsky (Contributor) Feb 27, 2024

Are they truly pure? No modification of external values?

Edit: Yeah, they just use a return value. I imagine this means that we actually have to check the bodies of PrimFuncs to determine if they're pure, and also give users the option to override the automatic judgment. The rules for that can be very simple: consider it impure if there is any write to a tensor or any external call.

Contributor Author

I was thinking an even simpler heuristic: all PrimFuncs are impure unless explicitly annotated otherwise. In this case, since the functions are being generated in a manner that requires purity, the generator could also provide the annotation.

For the long term, agreed, it would be good to have the TIR-level purity analysis. I think I'd weaken the condition you mentioned slightly: a function is impure if it writes to a buffer that it didn't itself allocate.
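The weakened rule can be sketched over a toy statement representation (the IR modeling below is a hypothetical stand-in, not TVM's actual TIR):

```python
# Toy sketch of the proposed TIR purity rule: a function is impure iff
# it writes to a buffer that it did not itself allocate. Statements are
# modeled as ("alloc", buffer_name) and ("store", buffer_name) tuples.
def is_pure(statements):
    allocated = set()
    for op, buf in statements:
        if op == "alloc":
            # Buffers allocated inside the function are local scratch
            # space; writes to them do not make the function impure.
            allocated.add(buf)
        elif op == "store" and buf not in allocated:
            # A write to an externally provided buffer is observable
            # by the caller, so the function is impure.
            return False
    return True
```
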

Contributor

Good point. I'm fine with requiring an annotation since the great majority of PrimFuncs are going to be impure.

}

auto ret_dtype = node->value->dtype;
auto param_vars = tir::UndefinedVars(node->value);
Contributor

Would this call know which TIR vars are in scope per the Relax scoping rules?

Contributor Author

This call isn't aware of the Relax scoping rules, but I don't think there's a benefit to checking them at this point. Any well-formed input that only uses in-scope TIR variables would produce well-formed output, and any ill-formed input that uses out-of-scope TIR variables would produce ill-formed output that still uses those out-of-scope TIR variables.

Validating the Relax scoping rules at this point would require additionally tracking the in-scope variables, which would duplicate the functionality of the well-formed checker. Since this pass wouldn't make any ill-formed usage worse (and therefore harder to debug), I don't think it's worth duplicating the in-scope tracking here.

Contributor

Yeah I think you're right that this would still work out just fine in that case.

Contributor Author

Probably overkill, but I think it ended up being simpler to just generate the appropriate FuncStructInfo for a PrimFunc on construction, and to populate the currently-empty struct_info_ field. This includes inspecting the body to see if the PrimFunc is pure.

Comment on lines +98 to +100
@T.prim_func(private=True)
def compute_symbolic_expr(N: T.int64, M: T.int64) -> T.int64:
    T.ret(N * M)
Contributor

As written, would this roundtrip? It seems you're manually setting the purity in the definition, but how would the parser know this?

Contributor Author

Good point. The test case only passes right now because the PrimFuncNode::struct_info_, if it exists, isn't checked for structural equality.

Contributor

Now that purity checking is implemented for PrimFuncs, this will roundtrip, correct?

@slyubomirsky (Contributor) commented Mar 1, 2024

Excellent, I'm glad to see an analysis for PrimFunc purity. Do you think it's necessary to have an override to assert that a PrimFunc is pure? It's doubtful that we would need it right now, but it might be a good thing to keep in our back pocket (i.e., an escape hatch). I think it could be an analogue to force_pure in Relax (a special attribute). An operator may have effects for some inputs while the programmer knows that in a particular case it doesn't, and other such situations can arise.

I will update both the TIR and Relax specs to take note of purity checking for PrimFuncs ^_^

@Lunderberg
Contributor Author

Do you think it's necessary to have an override to assert a PrimFunc is pure?

I think I prefer having the overrides at an individual operator level, rather than on the function as a whole. So, a user could use T.call_pure_packed instead of T.call_packed, in order to make the overall function be marked as pure. Especially when inlining/extracting regions of a function (e.g. SplitHostDevice), the top-down annotations can be ambiguous. Relying on bottom-up annotations allows each extracted region to have correct annotations, without requiring the developer to explicitly propagate function-level annotations.
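The bottom-up scheme can be sketched in plain Python (the call names mirror the TIR intrinsics discussed above; the derivation logic itself is an illustrative assumption, not TVM's implementation):

```python
# Sketch of bottom-up purity tracking: each call site records whether it
# uses a pure variant (e.g. T.call_pure_packed vs T.call_packed), and a
# function's purity is derived from its call sites rather than from a
# function-level annotation.
PURE_CALLS = {"call_pure_packed", "call_pure_extern"}

def function_is_pure(call_sites):
    # Because purity is derived bottom-up, a region extracted from this
    # function (e.g. by SplitHostDevice) can recompute its own purity
    # the same way, with no top-down annotation to propagate.
    return all(call in PURE_CALLS for call in call_sites)
```
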

@slyubomirsky
Contributor

That's a good idea. What would be the analogue to such annotations in a PrimFunc? Do we think we need it?

@Lunderberg
Contributor Author

That's a good idea. What would be the analogue to such annotations in a PrimFunc? Do we think we need it?

(Apologies, missed the notification on this one.) A PrimFunc already has a distinction between T.call_extern and T.call_pure_extern. It looks like a PrimFunc only has T.call_packed, and doesn't have a corresponding T.call_pure_packed. However, these are largely for internal functions, where the packed interface isn't necessary, so T.call_pure_extern would be sufficient.

@Lunderberg
Contributor Author

This PR needs an additional test case before merging. On the main branch, the following IRModule is valid. However, with this PR, it fails during parsing.

@I.ir_module
class Module:
    @R.function
    def main(A: R.Tensor) -> R.Prim("bool"):
        return Module.is_bfloat16_dtype(A)

    @T.prim_func(private=True)
    def is_bfloat16_dtype(tensor: T.handle) -> T.bool:
        T.func_attr({"tir.is_scheduled": True, "tir.is_host_func": True})

        # From #include <tvm/tir/builtin.h>
        kArrTypeCode = T.meta_var(5)
        kArrTypeBits = T.meta_var(6)
        kArrTypeLanes = T.meta_var(7)

        # From #include <dlpack/dlpack.h>
        kDLBfloat = T.meta_var(4)

        type_code = T.tvm_struct_get(tensor, 0, kArrTypeCode, dtype="uint8")
        type_bits = T.tvm_struct_get(tensor, 0, kArrTypeBits, dtype="uint8")
        type_lanes = T.tvm_struct_get(tensor, 0, kArrTypeLanes, dtype="uint16")

        is_bfloat16: T.bool = (
            (type_code == kDLBfloat) and (type_bits == 16) and (type_lanes == 1)
        )
        return is_bfloat16

The failure occurs due to the inferred struct info of R.Callable([R.Prim("handle")], R.Prim("bool")). When the function is passed an R.Tensor, it fails Relax's type check, as R.Tensor is not compatible with R.Prim("handle").

I think the fix will be to update StructInfoBaseCheck to recognize that R.Prim("handle") is a valid TIR representation of a DLTensor*, but I want to think on it first.
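The proposed relaxation can be sketched abstractly, with struct info modeled as plain strings (a hypothetical stand-in for the actual StructInfoBaseCheck logic, not TVM's API):

```python
# Sketch of the proposed relaxation of the struct-info base check:
# a handle-typed PrimFunc parameter is also a valid low-level
# representation of a DLTensor*, so an R.Tensor argument should pass.
def struct_info_base_check(param_sinfo, arg_sinfo):
    if param_sinfo == arg_sinfo:
        return True
    # Proposed special case: a handle-typed parameter may be bound to a
    # tensor argument, since tensors are passed as DLTensor* at runtime.
    if param_sinfo == 'R.Prim("handle")' and arg_sinfo == "R.Tensor":
        return True
    return False
```
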

@Lunderberg Lunderberg force-pushed the relax_allow_prim_bool_as_condition branch from d3d1b23 to eedc8b0 Compare March 12, 2024 18:59
@Lunderberg
Contributor Author

The additional test case has been added and fixed, so this PR is ready for review/merge.

return relax::TensorStructInfo(shape, buf->dtype);
}

if (auto prim_type = param->type_annotation.as<PrimTypeNode>();
@slyubomirsky (Contributor) Mar 21, 2024

By my understanding from the TIR spec, some handle-typed vars sound like they would correspond more to Objects in Relax, i.e., void-typed values that are completely opaque. Are we sure this shouldn't be the case?

Contributor Author

That's a good call, and I've updated it to provide R.Object instead of R.Tensor in this case. My main goal with this conditional was to avoid having DLTensor* parameters be labeled with R.Prim(dtype='handle'), as that would be incompatible with R.Tensor arguments.

In the future, the annotations could be improved by inspecting the PrimFunc body. If the handle-typed tir.Var is used in a context that requires a DLTensor* argument (typically builtin::tvm_struct_get), then the annotation could be improved from R.Object to R.Tensor. That would allow better type-checking at the Relax callsite, but isn't necessary for a first implementation.

@slyubomirsky
Contributor

Thank you for your changes, I am very excited to have purity checking and more accurate StructInfo for PrimFuncs. I had one question remaining about the inference for argument types, see my comment on those lines.

@Lunderberg Lunderberg force-pushed the relax_allow_prim_bool_as_condition branch from 8e2d599 to c17685d Compare March 22, 2024 15:21
* [TIR][Analysis] Implemented tir.analysis.is_pure_function

This commit introduces two related utilities,
`tir.analysis.is_pure_function` and `tir.analysis.assert_pure_function`.
In contrast to the existing `tvm::tir::SideEffect`, which checks for
side effects for a `PrimExpr`, `is_pure_function` checks for side
effects for the function as a whole.

* [Transform] Implement relax.transform.ComputePrimValue

Prior to this commit, while expressions of type `DataType::Int(64)`
could be computed in `relax.transform.VMShapeLower`, expressions
of any other type could not.  This commit introduces
`relax.transform.ComputePrimValue`, which produces `PrimFunc`
subroutines to compute `PrimExpr` values of any dtype.

This functionality will allow boolean values to be computed based on
the symbolic values known at runtime.

* [Relax] Allow R.Prim('bool') in relax::If and assert_op

Prior to this commit, the condition used for the `relax::If` node and
the `"relax.assert_op"` operator was required to be a scalar tensor.
This made it difficult to alter behavior based on a runtime shape
parameter; for example, delegating to a vectorized implementation
based on whether a tensor shape is divisible by the vector size.

This commit adds support for expressions of type `R.Prim('bool')` as
the conditional for `relax::If` and `"relax.assert_op"`, to allow
these use cases.
@Lunderberg Lunderberg force-pushed the relax_allow_prim_bool_as_condition branch from c17685d to f554738 Compare March 26, 2024 15:41
@Lunderberg
Contributor Author

Rebased onto main to resolve a merge conflict.

@slyubomirsky (Contributor) left a comment

Thank you for responding to feedback. I am pleased to see the improved type-checking for PrimFuncs in Relax.

@Lunderberg Lunderberg merged commit eb5458e into apache:main Mar 28, 2024
@Lunderberg Lunderberg deleted the relax_allow_prim_bool_as_condition branch March 28, 2024 21:53
@quic-sanirudh
Contributor

quic-sanirudh commented Apr 2, 2024

@Lunderberg After this PR introduced support for computing struct_info for PrimFuncs, I'm seeing a case where a PrimFunc whose output buffer is modified with the sch.transform_layout primitive and then added back to the module (similar to what happens in the legalization of R.layout_transform) causes a failure later in a different pass (the call_tir_rewrite pass, in my case).

A small test case that reproduces the issue is something like below:

import tvm
from tvm.script import relax as R, tir as T, ir as I
from tvm import tir, relax

@I.ir_module
class Before:
    @R.function
    def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tensor((16,), dtype="float32"):
        with R.dataflow():
            lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x, index_map=lambda i: (i // 4, i % 4), pad_value=None)
            gv: R.Tensor((4, 4), dtype="float32") = lv
            R.output(gv)
        return gv

if __name__ == '__main__':
    mod = Before
    mod = relax.transform.LegalizeOps()(mod)
    mod = relax.transform.CallTIRRewrite()(mod)

From what I've been able to understand, the FuncStructInfo for the PrimFunc is not recomputed when its output buffer is changed, which happens through a schedule primitive in the case of the R.layout_transform legalization. I could try to fix it by identifying cases where a PrimFunc can change its struct_info through any means (scheduling, in this case) and calling InferStructInfo in all those cases, but I wanted to ask if there are any better solutions.
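The staleness issue can be illustrated abstractly: a cached derived property must be rederived when the data it was computed from changes. All names below are hypothetical stand-ins, not TVM APIs:

```python
# Abstract sketch of the bug: struct info derived from a function's
# buffers goes stale when a schedule primitive rewrites those buffers.
# One fix is to rederive the cached info after every such rewrite.
def infer_struct_info(buffer_shapes):
    # Stand-in for struct-info inference from buffer shapes.
    return tuple(buffer_shapes)

class ToyPrimFunc:
    def __init__(self, buffer_shapes):
        self.buffer_shapes = list(buffer_shapes)
        # Derived property, cached at construction time.
        self.struct_info = infer_struct_info(self.buffer_shapes)

    def transform_layout(self, new_shape):
        # Analogous to sch.transform_layout changing the output buffer.
        self.buffer_shapes[-1] = new_shape
        # The fix: rederive the cached struct info after the rewrite;
        # omitting this line reproduces the stale-info failure mode.
        self.struct_info = infer_struct_info(self.buffer_shapes)
```
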

@Lunderberg
Contributor Author

Ooh, interesting. I wasn't aware that R.layout_transform was implemented in terms of Schedule.layout_transform. Can you take a look at #16832 and see if it resolves your issue?

thaisacs pushed a commit to thaisacs/tvm that referenced this pull request Apr 3, 2024
* [TIR][Analysis] Implemented tir.analysis.is_pure_function
* [Transform] Implement relax.transform.ComputePrimValue
* [Relax] Allow R.Prim('bool') in relax::If and assert_op
* Lint fix
@quic-sanirudh
Contributor

Ooh, interesting. I wasn't aware that R.layout_transform was implemented in terms of Schedule.layout_transform. Can you take a look at #16832 and see if it resolves your issue?

Thanks a lot for the very quick fix. Yes it does seem to solve the issue.
