From 532738ccedde10d045ec3a23d165934ea56b365e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 27 Jun 2022 09:38:02 -0500 Subject: [PATCH] [TIR] Improved error message if tir.Schedule passed to lower/build Previously, if a TIR Schedule is passed to `tvm.lower`, the error message is returned `ValueError: ('Expected input to be an IRModule, PrimFunc or Schedule, but got, ', )`. This can cause user confusion, as the expected class name in the error message does not differentiate between between a `tvm.te.Schedule` and a `tvm.tir.Schedule`. Updated error message to explicitly state that this should be a `te.Schedule`. --- python/tvm/driver/build_module.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 24e80686850d..47a922f7a3b1 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -23,13 +23,14 @@ import tvm.tir +from tvm import te + from tvm.runtime import Module from tvm.runtime import ndarray from tvm.ir import container from tvm.tir import PrimFunc from tvm.ir.module import IRModule from tvm.te import tensor -from tvm.te import schedule from tvm.target import Target from tvm.tir.buffer import Buffer from tvm.tir.expr import Var @@ -62,7 +63,7 @@ def get_binds(args, compact=False, binds=None): def schedule_to_module( - sch: schedule.Schedule, + sch: te.Schedule, args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, name: str = "main", binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, @@ -91,7 +92,7 @@ def schedule_to_module( def lower( - inp: Union[schedule.Schedule, PrimFunc, IRModule], + inp: Union[te.Schedule, PrimFunc, IRModule], args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, name: str = "main", binds: Optional[Mapping[tensor.Tensor, Buffer]] = None, @@ -129,13 +130,15 @@ def lower( return ffi.lower_module(inp, simple_mode) if isinstance(inp, PrimFunc): return ffi.lower_primfunc(inp, name, simple_mode) - if isinstance(inp, schedule.Schedule): + if isinstance(inp, te.Schedule): return ffi.lower_schedule(inp, args, name, binds, simple_mode) - raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp)) + raise ValueError( + f"Expected input to be an IRModule, PrimFunc or te.Schedule, but got {type(inp)}" + ) def build( - inputs: Union[schedule.Schedule, PrimFunc, IRModule, Mapping[str, IRModule]], + inputs: Union[te.Schedule, PrimFunc, IRModule, Mapping[str, IRModule]], args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None, target: Optional[Union[str, Target]] = None, target_host: Optional[Union[str, Target]] = None, @@ -219,7 +222,7 @@ def build( ---- See the note on :any:`tvm.target` on target string format. """ - if isinstance(inputs, schedule.Schedule): + if isinstance(inputs, te.Schedule): if args is None: raise ValueError("args must be given for build from schedule") input_mod = lower(inputs, args, name=name, binds=binds) @@ -234,7 +237,8 @@ def build( input_mod = lower(inputs) elif not isinstance(inputs, (dict, container.Map)): raise ValueError( - f"Inputs must be Schedule, IRModule or dict of target to IRModule, " + f"Inputs must be te.Schedule, IRModule, PrimFunc, " + f"or dict of target to IRModule, " f"but got {type(inputs)}." )