diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 9ff5bff5f1ff..568e05351aad 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -109,4 +109,4 @@ from . import analysis from . import stmt_functor from .build import build -from .pipeline import get_pipeline +from .pipeline import get_tir_pipeline, get_default_tir_pipeline diff --git a/python/tvm/tir/build.py b/python/tvm/tir/build.py index cd44ed881ba3..ee6280b74091 100644 --- a/python/tvm/tir/build.py +++ b/python/tvm/tir/build.py @@ -131,41 +131,51 @@ def build( assert isinstance(mod, tvm.IRModule) # Step 0: Determine the target in environment + # It's used to bind the PrimFunc without target attr to serve as a default target + target_to_bind = Target.current() if target is None else target + if target_to_bind is None: + target_to_bind = "llvm" + assert target_to_bind is not None + target_to_bind = Target.canon_target(target_to_bind) + + # Step 1: Determine the target to search for tir pipeline target = Target.current() if target is None else target if target is None: - target = "llvm" - assert target is not None - target = Target.canon_target(target) + for func in mod.functions.values(): + f_target = func.attrs.get("target", None) + if f_target is not None: + target = f_target + break + if target is not None: + target = Target.canon_target(target) - # Step 1: Determine the host + # Step 2: Determine the host target target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" if target is not None: if target.host is not None: target_host = target.host elif ndarray.device(target.kind.name, 0).device_type == ndarray.cpu(0).device_type: target_host = target - else: - for func in mod.functions.values(): - f_target = func.attrs.get("target", None) - if f_target is not None and f_target.host is not None: - target_host = f_target.host - assert target_host is not None target_host = Target.canon_target(target_host) - target = target.with_host(target_host) + target_to_bind = target_to_bind.with_host(target_host) - # Step 2: Bind the target to the input module - mod = tvm.tir.transform.BindTarget(target)(mod) + # Step 3: Bind the target to the input module + mod = tvm.tir.transform.BindTarget(target_to_bind)(mod) - # Step 3: Apply the pipeline + # Step 4: Apply the tir pipeline if pipeline is not None: + # custom pipeline if isinstance(pipeline, str): - pipeline = tvm.tir.get_pipeline(pipeline) - mod = pipeline(mod) + pipeline = tvm.tir.get_tir_pipeline(pipeline) + else: + # default pipeline depends on the target + pipeline = tvm.tir.get_default_tir_pipeline(target) + mod = pipeline(mod) - # Step 4: Get host and device modules + # Step 5: Get host and device modules host_mod, device_mod_dict = split_host_device_mods(mod) - # Step 5: Apply finalization passes + # Step 6: Apply finalization passes host_mod = tvm.tir.pipeline.finalize_host_passes()(host_mod) device_mod_dict = { target: tvm.tir.pipeline.finalize_device_passes()(device_mod) diff --git a/python/tvm/tir/pipeline.py b/python/tvm/tir/pipeline.py index 0b6d622c90e1..b7141bae30de 100644 --- a/python/tvm/tir/pipeline.py +++ b/python/tvm/tir/pipeline.py @@ -160,7 +160,7 @@ def finalize_device_passes(): # pylint: disable=unused-argument } -def get_pipeline(name: str = "default_tir", **kwargs) -> tvm.transform.Pass: +def get_tir_pipeline(name: str = "default_tir", **kwargs) -> tvm.transform.Pass: """Get pre-build pipeline by name Parameters @@ -173,3 +173,10 @@ def get_pipeline(name: str = "default_tir", **kwargs) -> tvm.transform.Pass: f"Unknown pre-built pipeline {name}," f"candidates are {list(PIPELINE_MAP.keys())}" ) return PIPELINE_MAP[name](**kwargs) + + +def get_default_tir_pipeline( + target: tvm.target.Target, # pylint: disable=unused-argument +) -> tvm.transform.Pass: + """Get the default TIR pipeline for the given target.""" + return default_tir_pipeline()