From cae341bb1b653032bc8c03904d9ce6b7598bd1fb Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Thu, 20 Feb 2025 20:24:57 +0000 Subject: [PATCH 1/3] fix --- python/tvm/tir/__init__.py | 2 +- python/tvm/tir/build.py | 51 ++++++++++++++++++++++---------------- python/tvm/tir/pipeline.py | 7 +++++- 3 files changed, 37 insertions(+), 23 deletions(-) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 9ff5bff5f1ff..568e05351aad 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -109,4 +109,4 @@ from . import analysis from . import stmt_functor from .build import build -from .pipeline import get_pipeline +from .pipeline import get_tir_pipeline, get_default_tir_pipeline diff --git a/python/tvm/tir/build.py b/python/tvm/tir/build.py index cd44ed881ba3..b7882a65b3bb 100644 --- a/python/tvm/tir/build.py +++ b/python/tvm/tir/build.py @@ -131,41 +131,50 @@ def build( assert isinstance(mod, tvm.IRModule) # Step 0: Determine the target in environment + # It's used to bind the PrimFunc without target attr to serve as a default target + target_to_bind = Target.current() if target is None else target + if target_to_bind is None: + target_to_bind = "llvm" + assert target_to_bind is not None + target_to_bind = Target.canon_target(target_to_bind) + + # Step 1: Determine the target to search for tir pipeline target = Target.current() if target is None else target if target is None: - target = "llvm" + for func in mod.functions.values(): + f_target = func.attrs.get("target", None) + if f_target is not None: + target = f_target + break assert target is not None target = Target.canon_target(target) - # Step 1: Determine the host + # Step 2: Determine the host target target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" - if target is not None: - if target.host is not None: - target_host = target.host - elif ndarray.device(target.kind.name, 0).device_type == ndarray.cpu(0).device_type: - target_host = target - else: - for func in mod.functions.values(): - f_target = func.attrs.get("target", None) - if f_target is not None and f_target.host is not None: - target_host = f_target.host - assert target_host is not None + if target.host is not None: + target_host = target.host + elif ndarray.device(target.kind.name, 0).device_type == ndarray.cpu(0).device_type: + target_host = target target_host = Target.canon_target(target_host) - target = target.with_host(target_host) + target_to_bind = target_to_bind.with_host(target_host) - # Step 2: Bind the target to the input module - mod = tvm.tir.transform.BindTarget(target)(mod) + # Step 3: Bind the target to the input module + mod = tvm.tir.transform.BindTarget(target_to_bind)(mod) - # Step 3: Apply the pipeline + # Step 4: Apply the tir pipeline if pipeline is not None: + # custom pipeline if isinstance(pipeline, str): - pipeline = tvm.tir.get_pipeline(pipeline) - mod = pipeline(mod) + pipeline = tvm.tir.get_tir_pipeline(pipeline) + else: + # default pipeline depends on the target + pipeline = tvm.tir.get_default_tir_pipeline(target) + mod = pipeline(mod) - # Step 4: Get host and device modules + # Step 5: Get host and device modules host_mod, device_mod_dict = split_host_device_mods(mod) - # Step 5: Apply finalization passes + # Step 6: Apply finalization passes host_mod = tvm.tir.pipeline.finalize_host_passes()(host_mod) device_mod_dict = { target: tvm.tir.pipeline.finalize_device_passes()(device_mod) diff --git a/python/tvm/tir/pipeline.py b/python/tvm/tir/pipeline.py index 0b6d622c90e1..c2d7405e1f92 100644 --- a/python/tvm/tir/pipeline.py +++ b/python/tvm/tir/pipeline.py @@ -160,7 +160,7 @@ def finalize_device_passes(): # pylint: disable=unused-argument } -def get_pipeline(name: str = "default_tir", **kwargs) -> tvm.transform.Pass: +def get_tir_pipeline(name: str = "default_tir", **kwargs) -> tvm.transform.Pass: """Get pre-build pipeline by name Parameters @@ -173,3 +173,8 @@ def get_pipeline(name: str = "default_tir", **kwargs) -> tvm.transform.Pass: f"Unknown pre-built pipeline {name}," f"candidates are {list(PIPELINE_MAP.keys())}" ) return PIPELINE_MAP[name](**kwargs) + + +def get_default_tir_pipeline(target: tvm.target.Target) -> tvm.transform.Pass: + """Get the default TIR pipeline for the given target.""" + return default_tir_pipeline() From a7feea88f5cb13e73c159213252a5a21799ebfb2 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Thu, 20 Feb 2025 20:36:11 +0000 Subject: [PATCH 2/3] fix --- python/tvm/tir/pipeline.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/tir/pipeline.py b/python/tvm/tir/pipeline.py index c2d7405e1f92..b7141bae30de 100644 --- a/python/tvm/tir/pipeline.py +++ b/python/tvm/tir/pipeline.py @@ -175,6 +175,8 @@ def get_tir_pipeline(name: str = "default_tir", **kwargs) -> tvm.transform.Pass: return PIPELINE_MAP[name](**kwargs) -def get_default_tir_pipeline(target: tvm.target.Target) -> tvm.transform.Pass: +def get_default_tir_pipeline( + target: tvm.target.Target, # pylint: disable=unused-argument +) -> tvm.transform.Pass: """Get the default TIR pipeline for the given target.""" return default_tir_pipeline() From 7ea9b49c43a34c7a15927f262b9d9559d0a37249 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Thu, 20 Feb 2025 21:04:13 +0000 Subject: [PATCH 3/3] fix --- python/tvm/tir/build.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/tvm/tir/build.py b/python/tvm/tir/build.py index b7882a65b3bb..ee6280b74091 100644 --- a/python/tvm/tir/build.py +++ b/python/tvm/tir/build.py @@ -146,15 +146,16 @@ def build( if f_target is not None: target = f_target break - assert target is not None - target = Target.canon_target(target) + if target is not None: + target = Target.canon_target(target) # Step 2: Determine the host target target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" - if target.host is not None: - target_host = target.host - elif ndarray.device(target.kind.name, 0).device_type == ndarray.cpu(0).device_type: - target_host = target + if target is not None: + if target.host is not None: + target_host = target.host + elif ndarray.device(target.kind.name, 0).device_type == ndarray.cpu(0).device_type: + target_host = target target_host = Target.canon_target(target_host) target_to_bind = target_to_bind.with_host(target_host)