python/tvm/relax/utils.py (21 additions, 5 deletions)
@@ -27,9 +27,16 @@
 from . import _ffi_api
 from .expr import Tuple as rx_Tuple
 from .expr import Expr, ShapeExpr, Function, PrimValue, StringImm, te_tensor
+from .expr import _update_struct_info
 from ..te import Tensor as te_Tensor, create_relax_prim_func
 from ..ir import Array, Attrs, Type, Map
-from .struct_info import PrimStructInfo, ShapeStructInfo, TensorStructInfo
+from .struct_info import (
+    PrimStructInfo,
+    ShapeStructInfo,
+    TensorStructInfo,
+    FuncStructInfo,
+    TupleStructInfo,
+)


 def metadata_partitioner(rx_txt: str) -> List[str]:
@@ -455,10 +462,19 @@ def _shape_with_old_tir_var(
     # with old set of variables.
     tir_var_inverse_map = {v: k for k, v in tir_var_map.items()}

-    output_sinfo = [
-        TensorStructInfo(_shape_with_old_tir_var(out.shape, tir_var_inverse_map), out.dtype)
-        for out in outs
-    ]
+    def te_to_sinfo(arg):
+        if isinstance(arg, tir.Var):
+            return PrimStructInfo(arg.dtype)
+        else:
+            return TensorStructInfo(
+                _shape_with_old_tir_var(arg.shape, tir_var_inverse_map), arg.dtype
+            )
+
+    input_sinfo = [te_to_sinfo(arg) for arg in [*te_args, *unbound_tir_vars]]
+    output_sinfo = [te_to_sinfo(out) for out in outs]
+
+    primfunc_sinfo = FuncStructInfo([*input_sinfo, *output_sinfo], PrimStructInfo("void"))
+    _update_struct_info(tir_func, primfunc_sinfo)

Contributor:
@tqchen Is this consistent with how we want FuncStructInfo to work? I thought PrimFuncs would be ObjectStructInfo (this is what we wrote in the Relax spec). Perhaps they use a derive_func instead? If we want them to use ordinary FuncStructInfo, does that also mean we'll allow them to be called outside of call_tir?

Contributor Author:
Good question; I had assumed this was intended, but I would be interested to hear more about it. I had mostly assumed that a Relax function and a TIR PrimFunc should expose the same information, so long as they follow the same convention. That is, since the callsite makes no distinction between a GlobalVar representing a relax::Function and one representing a tir::PrimFunc, it seemed that the struct_info_ should depend only on the call sequence, and not on the implementation dialect.

Regarding call_tir, I think the Relax-to-TIR calls are not restricted to the R.call_tir built-in, because the LowerCallTIR pass can output a relax::CallNode with a GlobalVar as its operation.
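
As a concrete illustration of that point, here is a minimal sketch (the names and struct info are made up, and the expression nodes are built directly, without normalization):

from tvm import relax
from tvm.ir import GlobalVar

# In a real IRModule these would be bound to a relax.Function and a
# tir.PrimFunc respectively; here they are only illustrative names.
gv_relax_fn = GlobalVar("relax_fn")
gv_prim_fn = GlobalVar("prim_fn")

x = relax.Var("x", relax.TensorStructInfo([16], "float32"))

# The callsite is a relax.Call whose op is a GlobalVar in both cases; nothing
# on the call node itself records which dialect implements the callee.
call_to_relax_fn = relax.Call(gv_relax_fn, [x])
call_to_prim_fn = relax.Call(gv_prim_fn, [x])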

Contributor:
Yeah, permitting TIR calls outside of call_tir is something we're trying to figure out with respect to phase ordering in Relax (see thread). I was under the impression that we did not want direct calls to PrimFuncs in the front end, so we should clarify that (we could put this on the agenda for a community meeting).

FWIW, I don't think it would be hard to give PrimFuncs FuncStructInfo, but there is the issue that they mutate their arguments, so they should be treated as impure (except when called via call_tir).

Contributor Author:
Good point regarding the mutation. Thinking on it, I'm also not sure what the best FuncStructInfo would be. It could reasonably be either FuncStructInfo(params=[*input_tensors, *output_tensors], ret=None), which matches the TIR function's signature, or FuncStructInfo(params=input_tensors, ret=relax.Tuple(output_tensors)), which matches the semantics exposed in Relax.

The original issue I was running into was that the result of bb.emit_te doesn't preserve the output struct information across mutations. If I have a TE function that accepts dynamic shapes, but which is called using static shapes, then the return type of the relax::Call should be an inferred static shape. This works during the first usage of BlockBuilder, when a user is calling bb.emit_te directly. However, when the module is mutated, any mutation of the call node relies on the relax::Normalizer to regenerate the output struct info, and it doesn't have enough information to do so.
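
To make the two candidates concrete, here is a minimal sketch using the struct_info constructors from the diff above, for an illustrative PrimFunc that reads two (n,) inputs and writes one (n,) output (the shapes, dtypes, and return conventions are assumptions for illustration):

from tvm import relax, tir

n = tir.Var("n", "int64")
tensor = relax.TensorStructInfo([n], "float32")

# Candidate 1: mirror the TIR signature, where the output is passed in as an
# argument and mutated in place, and nothing meaningful is returned.
tir_style = relax.FuncStructInfo(
    params=[tensor, tensor, tensor],  # two inputs plus the output buffer
    ret=relax.TupleStructInfo([]),    # unit; the diff above uses PrimStructInfo("void")
)

# Candidate 2: mirror the semantics exposed to Relax through call_tir, where
# only the inputs are parameters and the output appears as the return value.
relax_style = relax.FuncStructInfo(
    params=[tensor, tensor],
    ret=tensor,
)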

Contributor:
I think if you just call the PrimFunc by itself, it will work by mutating the arguments, so the best signature would be the first one you suggested. call_tir (the operator) is what's responsible for providing the nice wrapper over the mutation.

Contributor:
@Lunderberg I've put this PR on the agenda for next week's community meeting. If you can make it, that would be good.

Contributor Author:
Thank you! I probably won't be able to attend, given the timing, but I agree that discussion would be good. For now, I've converted this PR to a draft, to ensure that it can't be merged prior to discussion.

Contributor:
Conclusion from the meeting: we think it's okay to permit direct calls to PrimFuncs, so long as those calls are treated as impure, and to give PrimFuncs FuncStructInfo, which should likewise be marked as impure.

Contributor:
Incidentally, I wonder if unit (an empty tuple) would make more sense as the return type. Also, I do think the FuncStructInfo should have its purity set to false.
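
Spelled with the same constructors as the diff, the suggestion would presumably look like the sketch below; the purity keyword on FuncStructInfo is assumed here to be available, per the discussion above:

from tvm import relax

params = [relax.TensorStructInfo(ndim=2, dtype="float32")]

# As written in the diff: a "void" PrimStructInfo as the return.
sinfo_void = relax.FuncStructInfo(params, relax.PrimStructInfo("void"))

# The suggested alternative: unit (an empty tuple) as the return type, with
# purity set to false since the PrimFunc mutates its output arguments.
sinfo_unit = relax.FuncStructInfo(params, relax.TupleStructInfo([]), purity=False)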

Contributor Author:
Summarizing our conversation from this morning:

  • Shape propagation through bb.emit_te only works during the initial construction of a Relax module, when the relax::Call("relax.call_tir",...) node is explicitly typed. Re-derivation of the output shape is not implemented, so the shape information can be lost during lowering if the arguments to call_tir change (see the sketch after this list).

  • Annotating a PrimFunc with FuncStructInfo to represent the output of call_tir (i.e. pure function, tensor output) wouldn't be accurate, and could cause confusion in the future.

  • Annotating a PrimFunc with FuncStructInfo to represent the PrimFunc itself (i.e. impure function, mutates arguments) would be accurate, but insufficient for call_tir to propagate shapes, as input/output shapes are mixed.

  • It would be useful to have purity annotations for each parameter, dividing arguments into read-only, output, and mutate-in-place. This would allow a PrimFunc to be annotated accurately, and would be sufficient for call_tir to identify the outputs for shape propagation.
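
A minimal sketch of the first point, using a toy element-wise TE function; the static (16,) shapes and names are illustrative only:

from tvm import relax, te

def te_add(a, b):
    # Dynamic-shape TE computation; the output shape follows the inputs.
    return te.compute(a.shape, lambda *i: a(*i) + b(*i), name="add")

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo([16], "float32"))
y = relax.Var("y", relax.TensorStructInfo([16], "float32"))
with bb.function("main", [x, y]):
    # At construction time, emit_te infers a static (16,) output shape for the
    # generated call_tir node.  After later mutations, that shape must be
    # re-derived, which is where the information can currently be lost.
    out = bb.emit_te(te_add, x, y)
    bb.emit_func_output(out)

mod = bb.get()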


     tir_vars = None
     if len(unbound_tir_vars) > 0: