Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@

import tvm.tir

from tvm import te

from tvm.runtime import Module
from tvm.runtime import ndarray
from tvm.ir import container
from tvm.tir import PrimFunc
from tvm.ir.module import IRModule
from tvm.te import tensor
from tvm.te import schedule
from tvm.target import Target
from tvm.tir.buffer import Buffer
from tvm.tir.expr import Var
Expand Down Expand Up @@ -62,7 +63,7 @@ def get_binds(args, compact=False, binds=None):


def schedule_to_module(
sch: schedule.Schedule,
sch: te.Schedule,
args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
name: str = "main",
binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
Expand Down Expand Up @@ -91,7 +92,7 @@ def schedule_to_module(


def lower(
inp: Union[schedule.Schedule, PrimFunc, IRModule],
inp: Union[te.Schedule, PrimFunc, IRModule],
args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
name: str = "main",
binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
Expand Down Expand Up @@ -129,13 +130,15 @@ def lower(
return ffi.lower_module(inp, simple_mode)
if isinstance(inp, PrimFunc):
return ffi.lower_primfunc(inp, name, simple_mode)
if isinstance(inp, schedule.Schedule):
if isinstance(inp, te.Schedule):
return ffi.lower_schedule(inp, args, name, binds, simple_mode)
raise ValueError("Expected input to be an IRModule, PrimFunc or Schedule, but got, ", type(inp))
raise ValueError(
f"Expected input to be an IRModule, PrimFunc or te.Schedule, but got {type(inp)}"
)


def build(
inputs: Union[schedule.Schedule, PrimFunc, IRModule, Mapping[str, IRModule]],
inputs: Union[te.Schedule, PrimFunc, IRModule, Mapping[str, IRModule]],
args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
target: Optional[Union[str, Target]] = None,
target_host: Optional[Union[str, Target]] = None,
Expand Down Expand Up @@ -219,7 +222,7 @@ def build(
----
See the note on :any:`tvm.target` on target string format.
"""
if isinstance(inputs, schedule.Schedule):
if isinstance(inputs, te.Schedule):
if args is None:
raise ValueError("args must be given for build from schedule")
input_mod = lower(inputs, args, name=name, binds=binds)
Expand All @@ -234,7 +237,8 @@ def build(
input_mod = lower(inputs)
elif not isinstance(inputs, (dict, container.Map)):
raise ValueError(
f"Inputs must be Schedule, IRModule or dict of target to IRModule, "
f"Inputs must be te.Schedule, IRModule, PrimFunc, "
f"or dict of target to IRModule, "
f"but got {type(inputs)}."
)

Expand Down