From 0e41b27119e5a32feee43a28070edb6d87adfeb3 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Tue, 18 Feb 2025 15:48:22 +0000 Subject: [PATCH 01/17] dev --- python/tvm/driver/build_module.py | 525 +++++++++++++++++++++++--- python/tvm/tir/transform/transform.py | 33 ++ 2 files changed, 496 insertions(+), 62 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 94006111ffa2..9b8e01564fc8 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -17,52 +17,469 @@ # pylint: disable=invalid-name """The build utils in python.""" -from typing import Union, Optional - - -import tvm.tir - - +from typing import Union, Optional, Dict, List, Tuple +import enum +import tvm +from tvm import tir, ir from tvm.runtime import ndarray from tvm.tir import PrimFunc from tvm.ir.module import IRModule from tvm.target import Target -from tvm.driver import _ffi_api as _driver_ffi +from tvm._ffi.runtime_ctypes import Device -from . import _ffi_api as ffi +def create_pass_list(disable_loop_partition: bool): + """Create a list of passes based on pass context configurations. -def lower( - inp: Union[PrimFunc, IRModule], - name: str = "main", - simple_mode: bool = False, -) -> IRModule: - """Lowering step before build into target. + Parameters + ---------- + disable_loop_partition : bool + Whether to disable loop partition pass. + + Returns + ------- + List[tvm.tir.transform.Pass] + List of passes to run. + """ + pass_ctx = tvm.transform.PassContext.current() + config = pass_ctx.config + # Retrieve configuration flags. + disable_vectorize = config.get("tir.disable_vectorize", False) + disable_storage_rewrite = config.get("tir.disable_storage_rewrite", False) + instrument_bound_checkers = config.get("tir.instrument_bound_checkers", False) + disable_cse_tir = config.get("tir.disable_cse_tir", False) + enable_equiv_terms_in_cse_tir = config.get("tir.enable_equiv_terms_in_cse_tir", False) + ptx_ldg32 = config.get("tir.ptx_ldg32", False) + instrument_lwp = config.get("tir.instrument_lwp", False) + add_lower_pass = config.get("tir.add_lower_pass", []) + + # Group user passes by phase (phases 0, 1, 2, and 3 where phase>=3 goes to 3) + user_passes = {0: [], 1: [], 2: [], 3: []} + for phase, p in add_lower_pass: + if not isinstance(phase, int) or phase < 0: + raise ValueError( + f"Phase number must be a non-negative integer, got {phase} of type {type(phase)}" + ) + user_passes[phase if phase < 3 else 3].append(p) + + # Construct phase-specific passes. 
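+    # User passes arrive through the "tir.add_lower_pass" config as
+    # (phase, pass) pairs and are appended at the end of their phase;
+    # e.g. a hypothetical config={"tir.add_lower_pass": [(1, my_pass)]}
+    # runs my_pass after the built-in phase-1 passes.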
+ phase0 = user_passes[0] + + phase1 = [ + tir.transform.InjectPrefetch(), + tir.transform.TextureFlatten(), + tir.transform.StorageFlatten(64, instrument_bound_checkers), + tir.transform.LowerCrossThreadReduction(), + tir.transform.LowerInitBlock(), + tir.transform.PlanAndUpdateBufferAllocationLocation(), + tir.transform.ConvertBlocksToOpaque(), + tir.transform.LiftThreadBinding(), + tir.transform.ManifestSharedMemoryLocalStage(), + tir.transform.CompactBufferAllocation(), + tir.transform.LowerAutoCopy(), + tir.transform.UnifyThreadBinding(), + tir.transform.LowerMatchBuffer(), + tir.transform.Simplify(), + tir.transform.InjectPermutedLayout(), + tir.transform.Simplify(), + tir.transform.InjectSoftwarePipeline(), + tir.transform.TransformMmaBufferLayout(), + tir.transform.LowerOpaqueBlock(), + tir.transform.FlattenBuffer(), + tir.transform.BF16ComputeLegalize(), + tir.transform.NarrowDataType(32), + tir.transform.Simplify(), + ] + user_passes[1] + + phase2 = [] + if not disable_loop_partition: + phase2.append(tir.transform.LoopPartition()) + phase2.extend( + [ + tir.transform.VectorizeLoop(not disable_vectorize), + tir.transform.InjectVirtualThread(), + tir.transform.InjectDoubleBuffer(), + ] + ) + if not disable_storage_rewrite: + phase2.append(tir.transform.StorageRewrite()) + if config.get("tir.use_async_copy", False): + phase2.append(tir.transform.LowerAsyncDMA()) + phase2.extend( + [ + tir.transform.HoistIfThenElse(), + tir.transform.UnrollLoop(), + ] + ) + phase2 += user_passes[2] + + phase3 = [ + tir.transform.RenormalizeSplitPattern(), + tir.transform.Simplify(), + tir.transform.RemoveNoOp(), + tir.transform.RewriteUnsafeSelect(), + ] + user_passes[3] + + # Additional passes based on configuration. + extras = [] + if instrument_bound_checkers: + extras.append(tir.transform.InstrumentBoundCheckers()) + if ptx_ldg32: + extras.append(tir.transform.InjectPTXLDG32(True)) + extras.append( + tir.transform.CommonSubexprElimTIR(not disable_cse_tir, enable_equiv_terms_in_cse_tir) + ) + if instrument_lwp: + extras.append(tir.transform.InstrumentProfileIntrinsics()) + + return phase0 + phase1 + phase2 + phase3 + extras + + +def lower_module(inp: IRModule, simple_mode: bool = False) -> IRModule: + """Lowering step before building the target. Parameters ---------- - inp : Union[tvm.tir.PrimFunc, IRModule] - The TE schedule or TensorIR PrimFunc/IRModule to be built + inp : IRModule + The IRModule to be lowered. + simple_mode : bool + Whether to output only a simple, compact statement. + + Returns + ------- + IRModule + The lowered IRModule. + """ + return tvm.ir.transform.Sequential(create_pass_list(simple_mode))(inp) + + +def lower_primfunc(inp: PrimFunc, name: str = "main", simple_mode: bool = False) -> IRModule: + """Lowering step before building the target for a PrimFunc. + Parameters + ---------- + inp : PrimFunc + The PrimFunc to be lowered. name : str - The name of the result function. + The name of the resulting function. + simple_mode : bool + Whether to output only a simple, compact statement. + Returns + ------- + IRModule + The lowered IRModule. 
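+
+    Examples
+    --------
+    A minimal sketch, assuming ``func`` is a TVMScript ``@T.prim_func``::
+
+        mod = lower_primfunc(func, name="main")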
+ """ + pass_ctx = tvm.ir.transform.PassContext.current() + f = inp.with_attr("global_symbol", name) + if pass_ctx.config.get("tir.noalias", True): + f = f.with_attr("tir.noalias", True) + mod = tvm.ir.IRModule({tvm.ir.GlobalVar(name): f}) + return tvm.ir.transform.Sequential(create_pass_list(simple_mode))(mod) + + +def lower( + inp: Union[PrimFunc, IRModule], name: str = "main", simple_mode: bool = False +) -> IRModule: + """Lowering step before building the target. + + Parameters + ---------- + inp : Union[PrimFunc, IRModule] + The PrimFunc or IRModule to be lowered. + name : str + The name of the resulting function (if applicable). simple_mode : bool - Whether only output simple and compact statement, this will skip - LoopPartition, api wrapper generation and Unrolling. + Whether to output only a simple, compact statement. Returns ------- - m : IRModule - The result IRModule + IRModule + The lowered IRModule. """ if isinstance(inp, IRModule): - return ffi.lower_module(inp, simple_mode) + return lower_module(inp, simple_mode) if isinstance(inp, PrimFunc): - return ffi.lower_primfunc(inp, name, simple_mode) - raise ValueError( - f"Expected input to be an IRModule, PrimFunc or te.Schedule, but got {type(inp)}" + return lower_primfunc(inp, name, simple_mode) + raise ValueError(f"Expected input to be IRModule or PrimFunc, but got {type(inp)}") + + +def check_and_update_host_consistency(targets: dict, host): + """ + Check and update the host field of the given legacy heterogeneous targets + for legacy target API compatibility. + + Parameters + ---------- + targets : dict + Dictionary mapping Target objects to IRModule objects. + host : Target + The target host to be updated. + """ + for tgt in list(targets): + if getattr(tgt, "host", None) is None: + tgt.host = host + + +def mixed_module_pass_manager(target: Target) -> tvm.ir.transform.Sequential: + """ + Constructs a Sequential transformation pass pipeline for a mixed module. + + Parameters + ---------- + target : Target + The target device for which the module is intended. + + Returns + ------- + tvm.ir.transform.Sequential + A sequential pass pipeline for the mixed module. + """ + pass_ctx = tvm.ir.transform.PassContext.current() + mixed_pass_list = [ + # Bind the target first so that target-specific attributes are available. + tir.transform.BindTarget(target), + tir.transform.FP8ComputeLegalize(), + # VerifyVTCMLimit must occur before LowerVtcmAlloc. + tir.transform.VerifyVTCMLimit(target), + tir.transform.LowerVtcmAlloc(), + tir.transform.VerifyMemory(), + tir.transform.AnnotateEntryFunc(), + ] + if pass_ctx.config.get("tir.detect_global_barrier", False): + mixed_pass_list.append(tir.transform.ThreadSync("global")) + mixed_pass_list.extend( + [ + tir.transform.ThreadSync("shared"), + tir.transform.ThreadSync("shared.dyn"), + tir.transform.ThreadSync("warp"), + tir.transform.InferFragment(), + tir.transform.LowerThreadAllreduce(), + ] + ) + if pass_ctx.config.get("tir.use_async_copy", False): + mixed_pass_list.append(tir.transform.InjectPTXAsyncCopy()) + if pass_ctx.config.get("tir.ptx_ldg32", False): + mixed_pass_list.append(tir.transform.InjectPTXLDG32()) + mixed_pass_list.extend( + [ + tir.transform.AnnotateDeviceRegions(), + tir.transform.SplitHostDevice(), + # MergeSharedMemoryAllocations must follow SplitHostDevice. 
+ tir.transform.MergeSharedMemoryAllocations(), + tir.transform.MakePackedAPI(), + tir.transform.FP8StorageLegalize(), + tir.transform.BF16StorageLegalize(), + tir.transform.LowerDeviceKernelLaunch(), + ] ) + return tvm.ir.transform.Sequential(mixed_pass_list) + + +class CallConv(enum.IntEnum): + """ + Enum representing different calling conventions. + Corresponds to the C++ tvm::ir::CallingConv enum. + """ + + kDefault = 0 + kCPackedFunc = 1 + kDeviceKernelLaunch = 2 + + +def host_module_pass_manager(target_host: Target) -> tvm.ir.transform.Sequential: + """ + Build a sequential pass pipeline for lowering the host part of a mixed module. + + Parameters + ---------- + target_host : Target + The host target for which to lower the module. + + Returns + ------- + tvm.ir.transform.Sequential + A sequential pass pipeline for host-specific transformations. + """ + host_pass_list = [ + # Filter out device kernel launches. + tir.transform.Filter( + lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault)) + != int(CallConv.kDeviceKernelLaunch) + ), + tir.transform.BindTarget(target_host), + tir.transform.LowerTVMBuiltin(), + tir.transform.LowerCustomDatatypes(), + tir.transform.LowerIntrin(), + tir.transform.LowerDeviceStorageAccessInfo(), + tir.transform.CombineContextCall(), + ] + return tvm.ir.transform.Sequential(host_pass_list) + + +def device_module_pass_manager(target: Target) -> tvm.ir.transform.Sequential: + """ + Build a sequential pass pipeline for lowering the device part of a mixed module. + + Parameters + ---------- + target : Target + The target for device-specific transformations. + + Returns + ------- + tvm.ir.transform.Sequential + A sequential pass pipeline for device-specific transformations. + """ + device_pass_list = [ + # Select only device kernel launches. + tir.transform.Filter( + lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault)) + == int(CallConv.kDeviceKernelLaunch) + ), + tir.transform.BindTarget(target), + tir.transform.LowerWarpMemory(), + tir.transform.Simplify(), + tir.transform.LowerCustomDatatypes(), + tir.transform.LowerDeviceStorageAccessInfo(), + tir.transform.LowerIntrin(), + ] + return tvm.ir.transform.Sequential(device_pass_list) + + +def split_mixed_module( + mod_mixed: IRModule, target_arg: Target, target_host_arg: Target +) -> Tuple[IRModule, IRModule]: + """ + Split a mixed module containing both device and host parts into separate modules, + applying appropriate transformations on each. + + Parameters + ---------- + mod_mixed : IRModule + The input module containing both device and host code. + target_arg : Target + The target for device-specific transformations. + target_host_arg : Target + The host target for lowering. + + Returns + ------- + Tuple[IRModule, IRModule] + (host module, device module) + """ + target, target_host = target_arg, target_host_arg + if getattr(target, "host", None) is None: + target.host = target_host + if mod_mixed is None: + raise ValueError("Module must be defined") + + mod_mixed = mixed_module_pass_manager(target)(mod_mixed) + host_mod = host_module_pass_manager(target_host)(mod_mixed) + device_mod = device_module_pass_manager(target)(mod_mixed) + + # Warn if target is GPU but no device code was generated. + if "gpu" in target.keys and len(device_mod.functions) == 0: + print( + f"Warning: Specified target {target} but cannot find device code. Did you forget to bind?" 
+ ) + + return host_mod, device_mod + + +def default_target_host(target: Target) -> Target: + """ + Determine the default target host for a given target. + """ + if target is not None and target.device_type == Device.kDLCPU: + return target + # In practice, llvm_enabled should be determined dynamically. + llvm_enabled = True + return Target("llvm") if llvm_enabled else Target("stackvm") + + +def codegen_build(mod: IRModule, target: Target) -> tvm.runtime.Module: + """ + Build a runtime module from an IRModule and a Target. + + If the "tir.disable_assert" flag is set in the pass context, + the SkipAssert transformation is applied. + + Parameters + ---------- + mod : IRModule + The input IRModule. + target : Target + The target for which to build the module. + + Returns + ------- + tvm.runtime.Module + The built runtime module. + """ + if tvm.ir.transform.PassContext.current().config.get("tir.disable_assert", False): + mod = tvm.tir.transform.SkipAssert()(mod) + build_f_name = "target.build." + target.kind.name + bf = tvm.get_global_func(build_f_name) + if bf is None: + raise ValueError(f"{build_f_name} is not enabled") + return bf(mod, target) + + +def tir_to_runtime(inputs: Dict[Target, IRModule], target_host: Target): + """ + Convert a collection of TIR IRModules (keyed by Target) into a single runtime Module. + + Parameters + ---------- + inputs : dict + Mapping from Target to IRModule. + target_host : Target + The initial host target. + + Returns + ------- + tvm.runtime.Module + The final runtime module. + """ + if not inputs: + raise ValueError("TIRToRuntime expects at least one IRModule as input.") + + check_and_update_host_consistency(inputs, target_host) + if not target_host: + for tgt in inputs: + if tgt.get_target_device_type() == Device.kDLCPU: + target_host = tgt + break + if not target_host: + target_host = default_target_host(target_host) + check_and_update_host_consistency(inputs, target_host) + + first_module = next(iter(inputs.values())) + mhost_all = ir.IRModule({}, attrs=first_module.attrs) + if mhost_all is None: + raise ValueError("The host module must be defined") + + device_modules = [] + for tgt, ir_module in inputs.items(): + if ir_module: + host_mod, device_mod = split_mixed_module(ir_module, tgt, target_host) + overrides_host_target = ( + tgt.get_target_device_type() == target_host.get_target_device_type() + ) + non_host_target_kind = tgt.kind != target_host.kind + if overrides_host_target and non_host_target_kind: + device_modules.append(codegen_build(host_mod, tgt)) + else: + mhost_all.update(host_mod) + if len(device_mod.functions) != 0: + device_modules.append(codegen_build(device_mod, tgt)) + + mhost = codegen_build(mhost_all, target_host) + for dev_mod in device_modules: + if dev_mod is not None: + mhost.import_module(dev_mod) + return mhost def build( @@ -70,35 +487,29 @@ def build( target: Optional[Union[str, Target]] = None, name: str = "main", ): - """Build a function with arguments as signature. Code will be generated - for devices coupled with target information. + """ + Build a function with a signature, generating code for devices + coupled with target information. Parameters ---------- - input : Union[tvm.tir.PrimFunc, IRModule] - The input to be built - + inputs : Union[PrimFunc, IRModule] + The input to be built. target : Optional[Union[str, Target]] - The target and option of the compilation. - + The target for compilation. name : str - The name of result function. + The name of the result function. 
Returns ------- - ret : tvm.module - A module that combines both host and device code. - - Note - ---- - See the note on :any:`tvm.target` on target string format. + tvm.runtime.Module + A module combining both host and device code. """ if isinstance(inputs, PrimFunc): input_mod = lower(inputs, name=name) elif isinstance(inputs, tvm.IRModule): - assert ( - len(inputs.get_global_vars()) > 0 - ), "Expected a non-empty IRModule, but the IRModule contained no functions." + if not inputs.get_global_vars(): + raise ValueError("Expected a non-empty IRModule.") input_mod = lower(inputs) else: raise ValueError("Inputs must be IRModule or PrimFunc") @@ -107,41 +518,31 @@ def build( if target is None and isinstance(input_mod, tvm.IRModule): target_mod = {} for gvar, func in input_mod.functions.items(): - tgt = func.attrs["target"] if "target" in func.attrs else "llvm" - if tgt not in target_mod: - target_mod[tgt] = {} - target_mod[tgt][gvar] = func - - target_input_mod = {} - for tgt in target_mod.keys(): - tir_mod = tvm.IRModule(target_mod[tgt]) - tir_mod = tir_mod.with_attrs(input_mod.attrs) - target_input_mod[tgt] = tir_mod + tgt = func.attrs.get("target", "llvm") + target_mod.setdefault(tgt, {})[gvar] = func + target_input_mod = { + tgt: tvm.IRModule(funcs).with_attrs(input_mod.attrs) + for tgt, funcs in target_mod.items() + } else: target_input_mod = {target: input_mod} - # Because modules can be created from a variety of sources, we annotate them - # with the relevant attributes here to ensure they propagate annotated_mods = {} for tgt, mod in target_input_mod.items(): if not isinstance(tgt, (str, Target)): - raise ValueError("The key of inputs must be str or " "Target when inputs is dict.") + raise ValueError("The key of inputs must be str or Target.") if not isinstance(mod, tvm.IRModule): - raise ValueError("inputs must be IRModule, " "or dict of str to IRModule.") + raise ValueError("inputs must be IRModule, or dict of str to IRModule.") annotated_mods[tgt] = mod annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods) if not target_host: for tar, mod in annotated_mods.items(): - device_type = ndarray.device(tar.kind.name, 0).device_type - if device_type == ndarray.cpu(0).device_type: + if ndarray.device(tar.kind.name, 0).device_type == ndarray.cpu(0).device_type: target_host = tar break if not target_host: target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" - annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) - rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host) - - return rt_mod_host + return tir_to_runtime(annotated_mods, target_host) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index b08659e1c712..fb5a1ba79669 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -1200,3 +1200,36 @@ def UseAssumeToReduceBranches(): The result pass """ return _ffi_api.UseAssumeToReduceBranches() # type: ignore + + +def LowerAsyncDMA(): + """Lower async DMA to DMA. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerAsyncDMA() # type: ignore + + +def InjectPTXLDG32(enable_inject_ptx_intrin: bool = True): + """Inject ptx.ldg.32 intrinsics. + + Parameters + ---------- + enable_inject_ptx_intrin : bool + If True, inject ptx.ldg.32 intrinsics. + """ + return _ffi_api.InjectPTXLDG32(enable_inject_ptx_intrin) # type: ignore + + +def LowerVtcmAlloc(): + """Lower vtcm allocation. 
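+
+    VTCM here refers to the vector tightly coupled memory on Hexagon
+    targets; this pass lowers VTCM allocation for that backend.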
+ + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerVtcmAlloc() # type: ignore From ab99b995e072ee9bd15126cfe7d07d12b03df155 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Tue, 18 Feb 2025 16:22:51 +0000 Subject: [PATCH 02/17] dev --- python/tvm/driver/build_module.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 9b8e01564fc8..cd8f6f3c1e9a 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -17,7 +17,7 @@ # pylint: disable=invalid-name """The build utils in python.""" -from typing import Union, Optional, Dict, List, Tuple +from typing import Union, Optional, Dict, Tuple import enum import tvm from tvm import tir, ir @@ -381,7 +381,8 @@ def split_mixed_module( # Warn if target is GPU but no device code was generated. if "gpu" in target.keys and len(device_mod.functions) == 0: print( - f"Warning: Specified target {target} but cannot find device code. Did you forget to bind?" + f"Warning: Specified target {target} but cannot find device code. " + "Did you forget to bind?" ) return host_mod, device_mod From cfffe1d274f70c55be0c7a55c3ca7ea1efd433a9 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Tue, 18 Feb 2025 19:34:39 +0000 Subject: [PATCH 03/17] dev --- tests/python/codegen/test_target_codegen_cuda.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py index ae3173a14dee..b3cad9acd38e 100644 --- a/tests/python/codegen/test_target_codegen_cuda.py +++ b/tests/python/codegen/test_target_codegen_cuda.py @@ -766,6 +766,7 @@ def func3(A: T.Buffer((4, 4), "float32")) -> None: tvm.build(mod, target="cuda") +@tvm.testing.requires_cuda def test_invalid_reinterpret(): @T.prim_func def func(A: T.Buffer((4,), "uint32"), B: T.Buffer((4,), "uint8")) -> None: From 0e6bcfc4203e8572ee882a989186ee29b9bc5c43 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Tue, 18 Feb 2025 21:09:14 +0000 Subject: [PATCH 04/17] dev --- python/tvm/driver/build_module.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index cd8f6f3c1e9a..5dcfe64b71c8 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -44,13 +44,13 @@ def create_pass_list(disable_loop_partition: bool): pass_ctx = tvm.transform.PassContext.current() config = pass_ctx.config # Retrieve configuration flags. 
- disable_vectorize = config.get("tir.disable_vectorize", False) - disable_storage_rewrite = config.get("tir.disable_storage_rewrite", False) - instrument_bound_checkers = config.get("tir.instrument_bound_checkers", False) - disable_cse_tir = config.get("tir.disable_cse_tir", False) - enable_equiv_terms_in_cse_tir = config.get("tir.enable_equiv_terms_in_cse_tir", False) - ptx_ldg32 = config.get("tir.ptx_ldg32", False) - instrument_lwp = config.get("tir.instrument_lwp", False) + disable_vectorize = bool(config.get("tir.disable_vectorize", False)) + disable_storage_rewrite = bool(config.get("tir.disable_storage_rewrite", False)) + instrument_bound_checkers = bool(config.get("tir.instrument_bound_checkers", False)) + disable_cse_tir = bool(config.get("tir.disable_cse_tir", False)) + enable_equiv_terms_in_cse_tir = bool(config.get("tir.enable_equiv_terms_in_cse_tir", False)) + ptx_ldg32 = bool(config.get("tir.ptx_ldg32", False)) + instrument_lwp = bool(config.get("tir.instrument_lwp", False)) add_lower_pass = config.get("tir.add_lower_pass", []) # Group user passes by phase (phases 0, 1, 2, and 3 where phase>=3 goes to 3) From fa0f312f6bcd5741120f5a052620e640029cff1e Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Tue, 18 Feb 2025 23:04:26 +0000 Subject: [PATCH 05/17] dev --- include/tvm/driver/driver_api.h | 124 ------ python/tvm/driver/build_module.py | 3 + src/driver/driver_api.cc | 595 ------------------------- src/driver/internal_driver_api.h | 48 -- src/ir/transform.cc | 19 + src/relax/backend/vm/codegen_vm.cc | 1 - src/relax/backend/vm/codegen_vm_tir.cc | 1 - src/relax/transform/bind_params.cc | 1 - src/relax/transform/fold_constant.cc | 6 +- src/tir/transforms/primfunc_utils.cc | 1 - 10 files changed, 25 insertions(+), 774 deletions(-) delete mode 100644 include/tvm/driver/driver_api.h delete mode 100644 src/driver/driver_api.cc delete mode 100644 src/driver/internal_driver_api.h diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h deleted file mode 100644 index 39444d1629fe..000000000000 --- a/include/tvm/driver/driver_api.h +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/driver/driver_api.h - * \brief Compiler driver APIs to drive the compilation. - * - * This module provides end-to-end utils to drive the compilation process. - * We adopt the term "compiler driver" in common compiler infrastructures. - * Note that a compiler driver is different from "runtime drivers". - * Most of runtime related code are defined in the runtime folder instead. 
- */ -#ifndef TVM_DRIVER_DRIVER_API_H_ -#define TVM_DRIVER_DRIVER_API_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace tvm { -using tvm::transform::Pass; - -/*! - * \brief Configures and returns the composite Pass for the fused module (pre split) that contains - * device and host code. - * \param mixed_mod The original mixed module. - * \param target The device Target. - * \return The composite Pass for the fused module. -// */ -TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target); - -/*! - * \brief Configures and returns the composite Pass for the device Target after device/host from - * mixed module. - * \param mixed_mod The optimized mixed module. - * \param target The device Target. - * \return The composite Pass for the device module. - */ -TVM_DLL transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target); - -/*! - * \brief Configures and returns the composite Pass for the host Target after device/host from mixed - * module. - * \param mixed_mod The optimized mixed module. - * \param target_host The host Target. - * \return The composite Pass for the host module. - */ -TVM_DLL transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host); - -/*! - * \brief Lower an IRModule (optimize with it with the pass list defined in CreatePassList) - * \param mod The IRmodule to lower - * \param simple_mode Disables the loop partition pass. Defaults to false. - * \return The result module. - */ -TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false); - -/*! - * \brief Lower a primfunc and name (convert to IRModule, and optimize it with the pass list - * defined in CreatePassList) - * \param func The PrimFunc to lower - * \param name The name of the lowered function. - * \param simple_mode Disables the loop partition pass. Defaults to false. - * \return The result module. - */ -TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, - bool simple_mode = false); - -/*! - * \brief Build a device and host module for a specific target from an IRModule. - * \param funcs The functions to be built. - * \param target The target device to build for. - * \param target_host The target for building host code. To use the default, pass Target() - * \return The built module. - */ -TVM_DLL runtime::Module build(const IRModule& funcs, const Target& target, - const Target& target_host); - -/*! - * \brief Build a device and host module for a specific target from a map - * contains target to IRModule. This function is used - * for heterogeneous build. - * \param input The map contains target to an IRModule. - * \param target_host The target for building host code. To use the default, - * pass Target(). - * \return The built module that contains code for different processors. - */ -TVM_DLL runtime::Module build(const Map& input, const Target& target_host); - -/*! - * \brief Build a device and host module for a specific target from a map - * contains target to IRModule. This function is used - * for heterogeneous build. - * \param input The map contains target string to an IRModule. - * \param target_host The target for building host code. To use the default, - * pass Target(). - * \return The built module that contains code for different processors. 
- */ -TVM_DLL runtime::Module build(const Map& input, const Target& target_host); -} // namespace tvm - -#endif // TVM_DRIVER_DRIVER_API_H_ diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 5dcfe64b71c8..b5f79d300c09 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -547,3 +547,6 @@ def build( annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) return tir_to_runtime(annotated_mods, target_host) + + +tvm.register_func("tir.build", build) diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc deleted file mode 100644 index 5b12f13d96a6..000000000000 --- a/src/driver/driver_api.cc +++ /dev/null @@ -1,595 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Compile executable modules. - * \file driver_api.cc - */ -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace tvm { - -// Register build pipeline related options -TVM_REGISTER_PASS_CONFIG_OPTION("tir.noalias", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_buffer_level_predication", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array>); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_static_smem", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool); - -// WARNING: May cause coherency issues resulting data miscompares -// Experimental feature that, when enabled by the runtime, bypasses the cache when using DMA. When -// bypassing the cache TVM must manage cache coherency in software. Software managed cache coherency -// can be tricky e.g. it is yet to be proven out in the Hexagon runtime. Hence the warning above and -// the "experimental" notation for this feature. 
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.experimental_dma_bypass_cache", Bool); - -using tvm::Array; -using tvm::transform::Pass; - -bool LLVMEnabled() { - const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm"); - return pf != nullptr; -} - -/*! \return The default host target for a given device target */ -Target DefaultTargetHost(Target target) { - if (target.defined() && target->GetTargetDeviceType() == kDLCPU) { - return target; - } else { - if (LLVMEnabled()) { - return Target("llvm"); - } else { - return Target("stackvm"); - } - } -} - -void GetBinds(const Array& args, bool compact, - const std::unordered_map& binds, - Map* out_binds, Array* out_arg_list) { - *out_binds = binds; - - for (const ObjectRef& x : args) { - if (auto tensor_node = x.as()) { - te::Tensor x_ref = tensor_node.value(); - if (out_binds->find(x_ref) == out_binds->end()) { - tir::Buffer buf = tir::BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, - x_ref->op->name, -1, 0, compact); - out_binds->Set(x_ref, buf); - out_arg_list->push_back(buf); - } else { - out_arg_list->push_back((*out_binds)[x_ref]); - } - } else if (x.as() || x.as()) { - out_arg_list->push_back(x); - } else { - LOG(FATAL) - << "Expected type of the elements of args to be te::Tensor, te::Buffer or tir::Var, " - << "but got a " << x->GetTypeKey(); - } - } -} - -void GetBinds(const Array& args, bool compact, - const std::unordered_map& binds, - Map* out_binds, Array* out_arg_list) { - Array ref_args; - for (ObjectRef x : args) { - ref_args.push_back(x); - } - GetBinds(ref_args, compact, binds, out_binds, out_arg_list); -} - -TVM_REGISTER_GLOBAL("driver.get_binds") - .set_body_typed([](const Array& args, bool compact, - const Map& binds) { - std::unordered_map c_binds; - // Check to make sure binds is not null before doing the conversion; - if (binds.get() != nullptr) { - for (auto kv : binds) { - c_binds.insert({kv.first, kv.second}); - } - } - Map out_binds; - Array out_arg_list; - GetBinds(args, compact, c_binds, &out_binds, &out_arg_list); - - // TVM object system doesn't have a pair object, so we'll put both ret values in an array - // and return that. - Array out_arr = {out_binds, out_arg_list}; - return out_arr; - }); - -Array CreatePassList(bool disable_loop_partition) { - transform::PassContext pass_ctx = transform::PassContext::Current(); - - bool disable_vectorize = pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); - bool disable_storage_rewrite = - pass_ctx->GetConfig("tir.disable_storage_rewrite", Bool(false)).value(); - bool instrument_bound_checkers = - pass_ctx->GetConfig("tir.instrument_bound_checkers", Bool(false)).value(); - bool disable_cse_tir = pass_ctx->GetConfig("tir.disable_cse_tir", Bool(false)).value(); - bool enable_equiv_terms_in_cse_tir = - pass_ctx->GetConfig("tir.enable_equiv_terms_in_cse_tir", Bool(false)).value(); - - bool ptx_ldg32 = pass_ctx->GetConfig("tir.ptx_ldg32", Bool(false)).value(); - - // Get any user-added passes - Array> add_lower_pass = - pass_ctx->GetConfig>>("tir.add_lower_pass", Array>()) - .value(); - - bool instrument_lwp = pass_ctx->GetConfig("tir.instrument_lwp", Bool(false)).value(); - - Array user_lower_phase0 = Array(); - Array user_lower_phase1 = Array(); - Array user_lower_phase2 = Array(); - Array user_lower_phase3 = Array(); - - // phase passes is of the form - // [[phase_number, pass], [phase_number, pass]... 
] - for (Array phase_pass : add_lower_pass) { - auto phase_num = phase_pass[0].as(); - ICHECK(phase_num) - << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer, " - << "but instead received " << phase_pass[0] << " with type " << phase_pass[0]->GetTypeKey(); - int phase_num_val = phase_num->value; - - CHECK_GE(phase_num_val, 0); - - auto pass = Downcast(phase_pass[1]); - // Copy the pass into the correct phase - if (phase_num_val == 0) { - user_lower_phase0.push_back(pass); - } else if (phase_num_val == 1) { - user_lower_phase1.push_back(pass); - } else if (phase_num_val == 2) { - user_lower_phase2.push_back(pass); - } else if (phase_num_val >= 3) { - user_lower_phase3.push_back(pass); - } - } - - // Construct the pass list, inserting the user provided passes at the end of the phase - - // PHASE 0 - Array pass_list = user_lower_phase0; - - // PHASE 1 - pass_list.push_back(tir::transform::InjectPrefetch()); - pass_list.push_back(tir::transform::TextureFlatten()); - pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); - pass_list.push_back(tir::transform::LowerCrossThreadReduction()); - pass_list.push_back(tir::transform::LowerInitBlock()); - pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); - pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); - pass_list.push_back(tir::transform::LiftThreadBinding()); - pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage()); - pass_list.push_back(tir::transform::CompactBufferAllocation()); - pass_list.push_back(tir::transform::LowerAutoCopy()); - pass_list.push_back(tir::transform::UnifyThreadBinding()); - pass_list.push_back(tir::transform::LowerMatchBuffer()); - pass_list.push_back(tir::transform::Simplify()); - pass_list.push_back(tir::transform::InjectPermutedLayout()); - pass_list.push_back(tir::transform::Simplify()); - pass_list.push_back(tir::transform::InjectSoftwarePipeline()); - pass_list.push_back(tir::transform::TransformMmaBufferLayout()); - pass_list.push_back(tir::transform::LowerOpaqueBlock()); - pass_list.push_back(tir::transform::FlattenBuffer()); - pass_list.push_back(tir::transform::BF16ComputeLegalize()); - pass_list.push_back(tir::transform::NarrowDataType(32)); - pass_list.push_back(tir::transform::Simplify()); - - // Add user-defined phase-1 passes - pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end()); - - // PHASE 2 - if (!disable_loop_partition) { - pass_list.push_back(tir::transform::LoopPartition()); - } - - pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize)); - pass_list.push_back(tir::transform::InjectVirtualThread()); - pass_list.push_back(tir::transform::InjectDoubleBuffer()); - if (!disable_storage_rewrite) { - pass_list.push_back(tir::transform::StorageRewrite()); - } - bool use_async_copy = pass_ctx->GetConfig("tir.use_async_copy", Bool(false)).value(); - - if (use_async_copy) { - pass_list.push_back(tir::transform::LowerAsyncDMA()); - } - // HoistIfThenElse must be applied before UnrollLoop - // because HoistIfThenElse could utilize for loop structure - // which might be unrolled in UnrollLoop - pass_list.push_back(tir::transform::HoistIfThenElse()); - pass_list.push_back(tir::transform::UnrollLoop()); - - // Add user-defined phase-2 passes - pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end()); - - // PHASE 3 - pass_list.push_back(tir::transform::RenormalizeSplitPattern()); - 
pass_list.push_back(tir::transform::Simplify()); - pass_list.push_back(tir::transform::RemoveNoOp()); - pass_list.push_back(tir::transform::RewriteUnsafeSelect()); - - // Add user-defined phase-3 passes - pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end()); - - if (instrument_bound_checkers) { - pass_list.push_back(tir::transform::InstrumentBoundCheckers()); - } - - if (ptx_ldg32) { - pass_list.push_back(tir::transform::InjectPTXLDG32(true)); - } - - pass_list.push_back( - tir::transform::CommonSubexprElimTIR(!disable_cse_tir, enable_equiv_terms_in_cse_tir)); - - // This pass instruments the loops with the profile builtin calls to capture the runtime - // performance data (only enabled for Hexagon at the moment). To ensure that no other - // optimizations are performed on the instrumented code, this pass must be added at the end - // of the list. - if (instrument_lwp) { - pass_list.push_back(tir::transform::InstrumentProfileIntrinsics()); - } - - return pass_list; -} - -IRModule LowerWithPassList(IRModule mod, Array pass_list) { - auto optimize = tvm::transform::Sequential(pass_list); - mod = optimize(std::move(mod)); - return mod; -} - -IRModule ApplyPasses(IRModule mod, transform::Sequential seq) { - mod = seq(std::move(mod)); - return mod; -} - -IRModule LowerModule(IRModule mod, bool simple_mode) { - Array pass_list = CreatePassList(simple_mode); - return LowerWithPassList(std::move(mod), pass_list); -} - -TVM_REGISTER_GLOBAL("driver.lower_module").set_body_typed([](IRModule mod, bool simple_mode) { - return LowerModule(std::move(mod), simple_mode); -}); - -IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool simple_mode) { - transform::PassContext pass_ctx = transform::PassContext::Current(); - tir::PrimFunc f = WithAttr(std::move(func), "global_symbol", runtime::String(name)); - - bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); - - if (noalias) { - f = WithAttr(std::move(f), "tir.noalias", Bool(true)); - } - IRModule mod = IRModule(Map({{GlobalVar(name), f}})); - - // Get the pass list - Array pass_list = CreatePassList(simple_mode); - return LowerWithPassList(std::move(mod), pass_list); -} - -TVM_REGISTER_GLOBAL("driver.lower_primfunc") - .set_body_typed([](te::PrimFunc func, const String& name, bool simple_mode) { - return LowerPrimFunc(std::move(func), name, simple_mode); - }); - -/** - * This function takes the input module that contains both the device and host opts. - * Then, it applies transformation on the original module before splitting into separate modules for - * device and host. Then it also applies transformations on the new splitted modules. 
- */ -std::pair SplitMixedModule(IRModule mod_mixed, const Target& target_arg, - const Target& target_host_arg) { - Target target = target_arg, target_host = target_host_arg; - CheckAndUpdateHostConsistency(&target, &target_host); - - ICHECK(mod_mixed.defined()) << "This module must be defined"; - - mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target)); - - IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host)); - - IRModule device_mod = ApplyPasses(mod_mixed, DeviceModulePassManager(mod_mixed, target)); - - auto keys = target->GetKeys(); - - CheckAndUpdateHostConsistency(&target, &target_host); - - bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end(); - if (target_is_gpu && device_mod->functions.size() == 0) { - DLOG(WARNING) << "Specified target " << target->str() - << " but cannot find device code. Did you forget to bind?"; - } - - return {host_mod, device_mod}; -} - -/*! - * \brief Check and update host field of the given legacy heterogeneous targets and - * target host.Note that this function is for legacy target api compatibility issue only, - * not recommended for other use. - * \param ir_modules The pointer to a Map objects with keys being Target objects - * \param host The Target typed object for target host to be updated - */ -void CheckAndUpdateHostConsistency(Map* targets, Target* host) { - Map new_targets; - for (auto& it : *targets) { - auto target = it.first; - CheckAndUpdateHostConsistency(&target, host); - new_targets.Set(target, it.second); - } - *targets = new_targets; -} - -runtime::Module TIRToRuntime(const Map& inputs_arg, - const Target& target_host_arg) { - CHECK(inputs_arg.size()) << "TIRToRuntime expects at least one IRModule as input."; - std::vector device_modules; - Map inputs = inputs_arg; - Target target_host = target_host_arg; - - // Fetch previous defined target host in targets - CheckAndUpdateHostConsistency(&inputs, &target_host); - - if (!target_host.defined()) { - for (const auto& it : inputs) { - if (it.first->GetTargetDeviceType() == kDLCPU) { - target_host = it.first; - break; - } - } - } - - if (!target_host.defined()) { - target_host = DefaultTargetHost(target_host); - } - - // Update target host for all targets - CheckAndUpdateHostConsistency(&inputs, &target_host); - - // Take the attrs from the first module so the eventual modules have them. - // Ideally this would just be one unified module all the way through; - IRModule first_module = (*inputs.begin()).second; - IRModule mhost_all = IRModule(Map(), {}, first_module->attrs); - - ICHECK(mhost_all.defined()) << "The host module must be defined"; - - for (const auto& it : inputs) { - if (it.second.defined()) { - const Target& target = it.first; - const IRModule& ir_module = it.second; - auto pair = SplitMixedModule(ir_module, target, target_host); - auto& host_mod = pair.first; - auto& device_mod = pair.second; - - ICHECK(host_mod.defined()) << "The split host module must be defined"; - - ICHECK(mhost_all.defined()) << "The host module must be defined"; - - // We don't want library modules going back into host codegen - // unless they're supposed to. Here if we overrode the target host - // to allow lowering previously we check that it's meant to be placed - // back into the host Module. 
- bool overrides_host_target = - target->GetTargetDeviceType() == target_host->GetTargetDeviceType(); - bool non_host_target_kind = target->kind != target_host->kind; - if (overrides_host_target && non_host_target_kind) { - device_modules.push_back(codegen::Build(host_mod, it.first)); - } else { - mhost_all->Update(host_mod); - } - - if (device_mod->functions.size() != 0) { - device_modules.push_back(codegen::Build(device_mod, it.first)); - } - } - } - - runtime::Module mhost = codegen::Build(mhost_all, target_host); - for (const auto& it : device_modules) { - if (it.operator->()) { - mhost.Import(it); - } - } - - return mhost; -} - -TVM_REGISTER_GLOBAL("driver.tir_to_runtime") - .set_body_typed([](const Map& inputs_arg, Target host_target) { - return TIRToRuntime(inputs_arg, host_target); - }); - -// Build for heterogeneous execution when targets are specified as -// objects. This wrapper around the internal API is maintained for -// backwards compatibility. -runtime::Module build(const Map& input, const Target& target_host) { - return TIRToRuntime(input, target_host); -} - -// Build for heterogeneous execution when target is a string. -runtime::Module build(const Map& inputs_arg, const Target& target_host_arg) { - Map updated_inputs; - Target target_host = target_host_arg; - for (const auto& it : inputs_arg) { - Target target = Target(it.first); - CheckAndUpdateHostConsistency(&target, &target_host); - Optional device = target->GetAttr("device"); - if (device.defined() && device.value() == "vta") { - target = Target("ext_dev"); - } - updated_inputs.Set(target, it.second); - } - return TIRToRuntime(updated_inputs, target_host); -} - -// Build for homogeneous execution. -runtime::Module build(const IRModule& funcs, const Target& target_arg, - const Target& target_host_arg) { - auto target = target_arg, target_host = target_host_arg; - CheckAndUpdateHostConsistency(&target, &target_host); - // More maps of target and target host - Map inputs = {{target, funcs}}; - return TIRToRuntime(inputs, target_host); -} - -transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) { - transform::PassContext pass_ctx = transform::PassContext::Current(); - - Array mixed_pass_list; - - // FPComputeLegalize uses the target attrs added by BindTarget, so it must come first - mixed_pass_list.push_back(tir::transform::BindTarget(target)); - mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize()); - - // VerifyVTCMLimit must occur before LowerVtcmAlloc - mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target)); - // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations - mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc()); - - mixed_pass_list.push_back(tir::transform::VerifyMemory()); - - mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc()); - - bool detect_global_barrier = - pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value(); - if (detect_global_barrier) { - mixed_pass_list.push_back(tir::transform::ThreadSync("global")); - } - - mixed_pass_list.push_back(tir::transform::ThreadSync("shared")); - mixed_pass_list.push_back(tir::transform::ThreadSync("shared.dyn")); - mixed_pass_list.push_back(tir::transform::ThreadSync("warp")); - mixed_pass_list.push_back(tir::transform::InferFragment()); - mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce()); - - bool use_async_copy = pass_ctx->GetConfig("tir.use_async_copy", Bool(false)).value(); - - if (use_async_copy) { - 
mixed_pass_list.push_back(tir::transform::InjectPTXAsyncCopy()); - } - - bool ptx_ldg32 = pass_ctx->GetConfig("tir.ptx_ldg32", Bool(false)).value(); - if (ptx_ldg32) { - mixed_pass_list.push_back(tir::transform::InjectPTXLDG32()); - } - - mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions()); - mixed_pass_list.push_back(tir::transform::SplitHostDevice()); - // MergeSharedMemoryAllocations must be applied after SplitHostDevice - // because the merged allocation site is at the beginning of each device function - mixed_pass_list.push_back(tir::transform::MergeSharedMemoryAllocations()); - - mixed_pass_list.push_back(tir::transform::MakePackedAPI()); - mixed_pass_list.push_back(tir::transform::FP8StorageLegalize()); - mixed_pass_list.push_back(tir::transform::BF16StorageLegalize()); - - mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch()); - - return transform::Sequential(mixed_pass_list); -} - -TVM_REGISTER_GLOBAL("driver.mixed_mod_passes") - .set_body_typed([](IRModule mixed_mod, Target target) { - return MixedModulePassManager(mixed_mod, target); - }); - -transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) { - transform::PassContext pass_ctx = transform::PassContext::Current(); - - Array host_pass_list; - - runtime::TypedPackedFunc fcond = [](const tir::PrimFunc& f) { - return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != - CallingConv::kDeviceKernelLaunch; - }; - host_pass_list.push_back(tir::transform::Filter(fcond)); - - ICHECK(mixed_mod.defined()) << "This module must be defined"; - - host_pass_list.push_back(tir::transform::BindTarget(target_host)); - - host_pass_list.push_back(tir::transform::LowerTVMBuiltin()); - host_pass_list.push_back(tir::transform::LowerCustomDatatypes()); - host_pass_list.push_back(tir::transform::LowerIntrin()); - host_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); - host_pass_list.push_back(tir::transform::CombineContextCall()); - - return transform::Sequential(host_pass_list); -} - -TVM_REGISTER_GLOBAL("driver.host_mod_passes") - .set_body_typed([](IRModule mixed_mod, Target target_host) { - return HostModulePassManager(mixed_mod, target_host); - }); - -transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) { - Array device_pass_list; - runtime::TypedPackedFunc fcond = [](const tir::PrimFunc& f) { - return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == - CallingConv::kDeviceKernelLaunch; - }; - device_pass_list.push_back(tir::transform::Filter(fcond)); - - device_pass_list.push_back(tir::transform::BindTarget(target)); - - device_pass_list.push_back(tir::transform::LowerWarpMemory()); - device_pass_list.push_back(tir::transform::Simplify()); - device_pass_list.push_back(tir::transform::LowerCustomDatatypes()); - device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); - device_pass_list.push_back(tir::transform::LowerIntrin()); - - return transform::Sequential(device_pass_list); -} - -TVM_REGISTER_GLOBAL("driver.device_mod_passes") - .set_body_typed([](IRModule mixed_mod, Target target_host) { - return DeviceModulePassManager(mixed_mod, target_host); - }); - -} // namespace tvm diff --git a/src/driver/internal_driver_api.h b/src/driver/internal_driver_api.h deleted file mode 100644 index 3b7cc7c7f7fa..000000000000 --- a/src/driver/internal_driver_api.h +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license 
agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/driver/driver_api.h - * \brief Internal compiler driver APIs to drive the compilation. - * - * This module provides functionality that may be called internally - * within TVM, but is not part of the public-facing API. - */ -#ifndef TVM_DRIVER_INTERNAL_DRIVER_API_H_ -#define TVM_DRIVER_INTERNAL_DRIVER_API_H_ - -#include -#include - -namespace tvm { - -/*! - * \brief Build a device and host module for a specific target from a map - * contains target to IRModule. This function is used - * for heterogeneous build. - * \param input The map contains target to an IRModule. - * \param target_host The target for building host code. To use the default, - * pass Target(). - * \return The built module that contains code for different processors. - */ -runtime::Module TIRToRuntime(const Map& input, const Target& target_host); - -} // namespace tvm - -#endif // TVM_DRIVER_INTERNAL_DRIVER_API_H_ diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 0ed80310eb97..c65dda7d597a 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -46,6 +46,25 @@ using tvm::runtime::TVMArgs; using tvm::runtime::TVMRetValue; TVM_REGISTER_PASS_CONFIG_OPTION("testing.immutable_module", Bool); +// Register build pipeline related options +TVM_REGISTER_PASS_CONFIG_OPTION("tir.noalias", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_buffer_level_predication", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array>); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_static_smem", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool); struct PassContextThreadLocalEntry { /*! \brief The default pass context. 
*/ diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 8c0ddeb6c34d..18da88be805d 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -21,7 +21,6 @@ * \file src/relax/backend/vm/codegen_vm.cc * \brief A codegen to generate VM executable from a Relax IRModule. */ -#include #include #include #include diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index a92cf7c749a0..e3812ea8c101 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -21,7 +21,6 @@ * \file src/relax/backend/vm/codegen_tir.cc * \brief A codegen to generate VMTIR function(that can be compiled) from executable. */ -#include #include #include #include diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc index 27931b601760..14f68da3e4c1 100644 --- a/src/relax/transform/bind_params.cc +++ b/src/relax/transform/bind_params.cc @@ -17,7 +17,6 @@ * under the License. */ -#include #include #include #include diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index ff193acf143e..91a1e806cefc 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -17,7 +17,6 @@ * under the License. */ -#include #include #include #include @@ -116,8 +115,9 @@ class ConstantFolder : public ExprMutator { // already scheduled to only work on GPU, we will need to skip this in the const folder for // now // TODO(Hongyi): further check and narrow the scope of foldable function - runtime::Module rt_module = - build(LowerPrimFunc(func, "tir_function"), eval_cpu_target, eval_cpu_target); + auto* pf = runtime::Registry::Get("tir.build"); + ICHECK(pf != nullptr) << "Cannot find tir.build in registry"; + runtime::Module rt_module = (*pf)(func, eval_cpu_target, "tir_function"); build_func = rt_module.GetFunction("tir_function"); } catch (const tvm::Error& err) { // build failure may happen in which case we skip diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc index 7f45fee9a26c..d5946fda216f 100644 --- a/src/tir/transforms/primfunc_utils.cc +++ b/src/tir/transforms/primfunc_utils.cc @@ -22,7 +22,6 @@ * \brief Passes that serve as helper functions. */ -#include #include namespace tvm { From cff57c47e749a9bc205b78a5820a67cedf89d889 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Wed, 19 Feb 2025 01:32:32 +0000 Subject: [PATCH 06/17] dev --- python/tvm/driver/build_module.py | 19 +- python/tvm/tir/build.py | 557 ++++++++++++++++++++++++++++++ 2 files changed, 570 insertions(+), 6 deletions(-) create mode 100644 python/tvm/tir/build.py diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index b5f79d300c09..acc1393d167d 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -506,19 +506,23 @@ def build( tvm.runtime.Module A module combining both host and device code. 
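+
+    Examples
+    --------
+    A minimal usage sketch (``func`` is assumed to be a scripted PrimFunc)::
+
+        rt_mod = build(func, target="llvm")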
""" + # Convert PrimFunc to IRModule + pass_ctx = tvm.ir.transform.PassContext.current() if isinstance(inputs, PrimFunc): - input_mod = lower(inputs, name=name) + f = inputs.with_attr("global_symbol", name) + if pass_ctx.config.get("tir.noalias", True): + f = f.with_attr("tir.noalias", True) + input_mod = tvm.ir.IRModule({tvm.ir.GlobalVar(name): f}) elif isinstance(inputs, tvm.IRModule): - if not inputs.get_global_vars(): - raise ValueError("Expected a non-empty IRModule.") - input_mod = lower(inputs) + input_mod = inputs else: raise ValueError("Inputs must be IRModule or PrimFunc") + # Get target and target_host target = Target.current() if target is None else target - if target is None and isinstance(input_mod, tvm.IRModule): + if target is None and isinstance(inputs, tvm.IRModule): target_mod = {} - for gvar, func in input_mod.functions.items(): + for gvar, func in inputs.functions.items(): tgt = func.attrs.get("target", "llvm") target_mod.setdefault(tgt, {})[gvar] = func target_input_mod = { @@ -546,6 +550,9 @@ def build( target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) + for tgt, mod in annotated_mods.items(): + mod = lower_module(mod, simple_mode=False) + annotated_mods[tgt] = mod return tir_to_runtime(annotated_mods, target_host) diff --git a/python/tvm/tir/build.py b/python/tvm/tir/build.py new file mode 100644 index 000000000000..b62d799390a5 --- /dev/null +++ b/python/tvm/tir/build.py @@ -0,0 +1,557 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=invalid-name +"""The build utils in python.""" +from typing import Union, Optional, Dict, Tuple +import enum +import tvm +from tvm import tir, ir +from tvm.runtime import ndarray +from tvm.tir import PrimFunc +from tvm.ir.module import IRModule +from tvm.target import Target +from tvm._ffi.runtime_ctypes import Device + + +def create_pass_list(disable_loop_partition: bool): + """Create a list of passes based on pass context configurations. + + Parameters + ---------- + disable_loop_partition : bool + Whether to disable loop partition pass. + + Returns + ------- + List[tvm.tir.transform.Pass] + List of passes to run. + """ + pass_ctx = tvm.transform.PassContext.current() + config = pass_ctx.config + # Retrieve configuration flags. 
+ disable_vectorize = bool(config.get("tir.disable_vectorize", False)) + disable_storage_rewrite = bool(config.get("tir.disable_storage_rewrite", False)) + instrument_bound_checkers = bool(config.get("tir.instrument_bound_checkers", False)) + disable_cse_tir = bool(config.get("tir.disable_cse_tir", False)) + enable_equiv_terms_in_cse_tir = bool(config.get("tir.enable_equiv_terms_in_cse_tir", False)) + ptx_ldg32 = bool(config.get("tir.ptx_ldg32", False)) + instrument_lwp = bool(config.get("tir.instrument_lwp", False)) + add_lower_pass = config.get("tir.add_lower_pass", []) + + # Group user passes by phase (phases 0, 1, 2, and 3 where phase>=3 goes to 3) + user_passes = {0: [], 1: [], 2: [], 3: []} + for phase, p in add_lower_pass: + if not isinstance(phase, int) or phase < 0: + raise ValueError( + f"Phase number must be a non-negative integer, got {phase} of type {type(phase)}" + ) + user_passes[phase if phase < 3 else 3].append(p) + + # Construct phase-specific passes. + phase0 = user_passes[0] + + phase1 = [ + tir.transform.InjectPrefetch(), + tir.transform.TextureFlatten(), + tir.transform.StorageFlatten(64, instrument_bound_checkers), + tir.transform.LowerCrossThreadReduction(), + tir.transform.LowerInitBlock(), + tir.transform.PlanAndUpdateBufferAllocationLocation(), + tir.transform.ConvertBlocksToOpaque(), + tir.transform.LiftThreadBinding(), + tir.transform.ManifestSharedMemoryLocalStage(), + tir.transform.CompactBufferAllocation(), + tir.transform.LowerAutoCopy(), + tir.transform.UnifyThreadBinding(), + tir.transform.LowerMatchBuffer(), + tir.transform.Simplify(), + tir.transform.InjectPermutedLayout(), + tir.transform.Simplify(), + tir.transform.InjectSoftwarePipeline(), + tir.transform.TransformMmaBufferLayout(), + tir.transform.LowerOpaqueBlock(), + tir.transform.FlattenBuffer(), + tir.transform.BF16ComputeLegalize(), + tir.transform.NarrowDataType(32), + tir.transform.Simplify(), + ] + user_passes[1] + + phase2 = [] + if not disable_loop_partition: + phase2.append(tir.transform.LoopPartition()) + phase2.extend( + [ + tir.transform.VectorizeLoop(not disable_vectorize), + tir.transform.InjectVirtualThread(), + tir.transform.InjectDoubleBuffer(), + ] + ) + if not disable_storage_rewrite: + phase2.append(tir.transform.StorageRewrite()) + if config.get("tir.use_async_copy", False): + phase2.append(tir.transform.LowerAsyncDMA()) + phase2.extend( + [ + tir.transform.HoistIfThenElse(), + tir.transform.UnrollLoop(), + ] + ) + phase2 += user_passes[2] + + phase3 = [ + tir.transform.RenormalizeSplitPattern(), + tir.transform.Simplify(), + tir.transform.RemoveNoOp(), + tir.transform.RewriteUnsafeSelect(), + ] + user_passes[3] + + # Additional passes based on configuration. + extras = [] + if instrument_bound_checkers: + extras.append(tir.transform.InstrumentBoundCheckers()) + if ptx_ldg32: + extras.append(tir.transform.InjectPTXLDG32(True)) + extras.append( + tir.transform.CommonSubexprElimTIR(not disable_cse_tir, enable_equiv_terms_in_cse_tir) + ) + if instrument_lwp: + extras.append(tir.transform.InstrumentProfileIntrinsics()) + + return phase0 + phase1 + phase2 + phase3 + extras + + +def lower_module(inp: IRModule, simple_mode: bool = False) -> IRModule: + """Lowering step before building the target. + + Parameters + ---------- + inp : IRModule + The IRModule to be lowered. + simple_mode : bool + Whether to output only a simple, compact statement. + + Returns + ------- + IRModule + The lowered IRModule. 
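+
+    Examples
+    --------
+    A minimal sketch, assuming ``mod`` is an IRModule written in TVMScript::
+
+        lowered = lower_module(mod)
+        print(lowered.script())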
+ """ + return tvm.ir.transform.Sequential(create_pass_list(simple_mode))(inp) + + +def lower_primfunc(inp: PrimFunc, name: str = "main", simple_mode: bool = False) -> IRModule: + """Lowering step before building the target for a PrimFunc. + + Parameters + ---------- + inp : PrimFunc + The PrimFunc to be lowered. + name : str + The name of the resulting function. + simple_mode : bool + Whether to output only a simple, compact statement. + + Returns + ------- + IRModule + The lowered IRModule. + """ + pass_ctx = tvm.ir.transform.PassContext.current() + f = inp.with_attr("global_symbol", name) + if pass_ctx.config.get("tir.noalias", True): + f = f.with_attr("tir.noalias", True) + mod = tvm.ir.IRModule({tvm.ir.GlobalVar(name): f}) + return tvm.ir.transform.Sequential(create_pass_list(simple_mode))(mod) + + +def lower( + inp: Union[PrimFunc, IRModule], name: str = "main", simple_mode: bool = False +) -> IRModule: + """Lowering step before building the target. + + Parameters + ---------- + inp : Union[PrimFunc, IRModule] + The PrimFunc or IRModule to be lowered. + name : str + The name of the resulting function (if applicable). + simple_mode : bool + Whether to output only a simple, compact statement. + + Returns + ------- + IRModule + The lowered IRModule. + """ + if isinstance(inp, IRModule): + return lower_module(inp, simple_mode) + if isinstance(inp, PrimFunc): + return lower_primfunc(inp, name, simple_mode) + raise ValueError(f"Expected input to be IRModule or PrimFunc, but got {type(inp)}") + + +def check_and_update_host_consistency(targets: dict, host): + """ + Check and update the host field of the given legacy heterogeneous targets + for legacy target API compatibility. + + Parameters + ---------- + targets : dict + Dictionary mapping Target objects to IRModule objects. + host : Target + The target host to be updated. + """ + for tgt in list(targets): + if getattr(tgt, "host", None) is None: + tgt.host = host + + +def mixed_module_pass_manager(target: Target) -> tvm.ir.transform.Sequential: + """ + Constructs a Sequential transformation pass pipeline for a mixed module. + + Parameters + ---------- + target : Target + The target device for which the module is intended. + + Returns + ------- + tvm.ir.transform.Sequential + A sequential pass pipeline for the mixed module. + """ + pass_ctx = tvm.ir.transform.PassContext.current() + mixed_pass_list = [ + # Bind the target first so that target-specific attributes are available. + tir.transform.BindTarget(target), + tir.transform.FP8ComputeLegalize(), + # VerifyVTCMLimit must occur before LowerVtcmAlloc. + tir.transform.VerifyVTCMLimit(target), + tir.transform.LowerVtcmAlloc(), + tir.transform.VerifyMemory(), + tir.transform.AnnotateEntryFunc(), + ] + if pass_ctx.config.get("tir.detect_global_barrier", False): + mixed_pass_list.append(tir.transform.ThreadSync("global")) + mixed_pass_list.extend( + [ + tir.transform.ThreadSync("shared"), + tir.transform.ThreadSync("shared.dyn"), + tir.transform.ThreadSync("warp"), + tir.transform.InferFragment(), + tir.transform.LowerThreadAllreduce(), + ] + ) + if pass_ctx.config.get("tir.use_async_copy", False): + mixed_pass_list.append(tir.transform.InjectPTXAsyncCopy()) + if pass_ctx.config.get("tir.ptx_ldg32", False): + mixed_pass_list.append(tir.transform.InjectPTXLDG32()) + mixed_pass_list.extend( + [ + tir.transform.AnnotateDeviceRegions(), + tir.transform.SplitHostDevice(), + # MergeSharedMemoryAllocations must follow SplitHostDevice. 
+ tir.transform.MergeSharedMemoryAllocations(), + tir.transform.MakePackedAPI(), + tir.transform.FP8StorageLegalize(), + tir.transform.BF16StorageLegalize(), + tir.transform.LowerDeviceKernelLaunch(), + ] + ) + return tvm.ir.transform.Sequential(mixed_pass_list) + + +class CallConv(enum.IntEnum): + """ + Enum representing different calling conventions. + Corresponds to the C++ tvm::ir::CallingConv enum. + """ + + kDefault = 0 + kCPackedFunc = 1 + kDeviceKernelLaunch = 2 + + +def host_module_pass_manager(target_host: Target) -> tvm.ir.transform.Sequential: + """ + Build a sequential pass pipeline for lowering the host part of a mixed module. + + Parameters + ---------- + target_host : Target + The host target for which to lower the module. + + Returns + ------- + tvm.ir.transform.Sequential + A sequential pass pipeline for host-specific transformations. + """ + host_pass_list = [ + # Filter out device kernel launches. + tir.transform.Filter( + lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault)) + != int(CallConv.kDeviceKernelLaunch) + ), + tir.transform.BindTarget(target_host), + tir.transform.LowerTVMBuiltin(), + tir.transform.LowerCustomDatatypes(), + tir.transform.LowerIntrin(), + tir.transform.LowerDeviceStorageAccessInfo(), + tir.transform.CombineContextCall(), + ] + return tvm.ir.transform.Sequential(host_pass_list) + + +def device_module_pass_manager(target: Target) -> tvm.ir.transform.Sequential: + """ + Build a sequential pass pipeline for lowering the device part of a mixed module. + + Parameters + ---------- + target : Target + The target for device-specific transformations. + + Returns + ------- + tvm.ir.transform.Sequential + A sequential pass pipeline for device-specific transformations. + """ + device_pass_list = [ + # Select only device kernel launches. + tir.transform.Filter( + lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault)) + == int(CallConv.kDeviceKernelLaunch) + ), + tir.transform.BindTarget(target), + tir.transform.LowerWarpMemory(), + tir.transform.Simplify(), + tir.transform.LowerCustomDatatypes(), + tir.transform.LowerDeviceStorageAccessInfo(), + tir.transform.LowerIntrin(), + ] + return tvm.ir.transform.Sequential(device_pass_list) + + +def split_mixed_module( + mod_mixed: IRModule, target_arg: Target, target_host_arg: Target +) -> Tuple[IRModule, IRModule]: + """ + Split a mixed module containing both device and host parts into separate modules, + applying appropriate transformations on each. + + Parameters + ---------- + mod_mixed : IRModule + The input module containing both device and host code. + target_arg : Target + The target for device-specific transformations. + target_host_arg : Target + The host target for lowering. + + Returns + ------- + Tuple[IRModule, IRModule] + (host module, device module) + """ + target, target_host = target_arg, target_host_arg + if getattr(target, "host", None) is None: + target.host = target_host + if mod_mixed is None: + raise ValueError("Module must be defined") + + mod_mixed = mixed_module_pass_manager(target)(mod_mixed) + host_mod = host_module_pass_manager(target_host)(mod_mixed) + device_mod = device_module_pass_manager(target)(mod_mixed) + + # Warn if target is GPU but no device code was generated. + if "gpu" in target.keys and len(device_mod.functions) == 0: + print( + f"Warning: Specified target {target} but cannot find device code. " + "Did you forget to bind?" 
+        )
+
+    return host_mod, device_mod
+
+
+def default_target_host(target: Target) -> Target:
+    """
+    Determine the default target host for a given target.
+    """
+    if target is not None and target.get_target_device_type() == Device.kDLCPU:
+        return target
+    # Query LLVM support from the runtime rather than assuming it.
+    llvm_enabled = tvm.runtime.enabled("llvm")
+    return Target("llvm") if llvm_enabled else Target("stackvm")
+
+
+def codegen_build(mod: IRModule, target: Target) -> tvm.runtime.Module:
+    """
+    Build a runtime module from an IRModule and a Target.
+
+    If the "tir.disable_assert" flag is set in the pass context,
+    the SkipAssert transformation is applied.
+
+    Parameters
+    ----------
+    mod : IRModule
+        The input IRModule.
+    target : Target
+        The target for which to build the module.
+
+    Returns
+    -------
+    tvm.runtime.Module
+        The built runtime module.
+    """
+    if tvm.ir.transform.PassContext.current().config.get("tir.disable_assert", False):
+        mod = tvm.tir.transform.SkipAssert()(mod)
+    build_f_name = "target.build." + target.kind.name
+    # get_global_func raises by default; allow_missing lets us report a clearer error.
+    bf = tvm.get_global_func(build_f_name, allow_missing=True)
+    if bf is None:
+        raise ValueError(f"{build_f_name} is not enabled")
+    return bf(mod, target)
+
+
+def tir_to_runtime(inputs: Dict[Target, IRModule], target_host: Target):
+    """
+    Convert a collection of TIR IRModules (keyed by Target) into a single runtime Module.
+
+    Parameters
+    ----------
+    inputs : dict
+        Mapping from Target to IRModule.
+    target_host : Target
+        The initial host target.
+
+    Returns
+    -------
+    tvm.runtime.Module
+        The final runtime module.
+    """
+    if not inputs:
+        raise ValueError("TIRToRuntime expects at least one IRModule as input.")
+
+    check_and_update_host_consistency(inputs, target_host)
+    if not target_host:
+        for tgt in inputs:
+            if tgt.get_target_device_type() == Device.kDLCPU:
+                target_host = tgt
+                break
+    if not target_host:
+        target_host = default_target_host(target_host)
+    check_and_update_host_consistency(inputs, target_host)
+
+    first_module = next(iter(inputs.values()))
+    mhost_all = ir.IRModule({}, attrs=first_module.attrs)
+    if mhost_all is None:
+        raise ValueError("The host module must be defined")
+
+    device_modules = []
+    for tgt, ir_module in inputs.items():
+        if ir_module:
+            host_mod, device_mod = split_mixed_module(ir_module, tgt, target_host)
+            overrides_host_target = (
+                tgt.get_target_device_type() == target_host.get_target_device_type()
+            )
+            non_host_target_kind = tgt.kind != target_host.kind
+            if overrides_host_target and non_host_target_kind:
+                device_modules.append(codegen_build(host_mod, tgt))
+            else:
+                mhost_all.update(host_mod)
+            if len(device_mod.functions) != 0:
+                device_modules.append(codegen_build(device_mod, tgt))
+
+    mhost = codegen_build(mhost_all, target_host)
+    for dev_mod in device_modules:
+        if dev_mod is not None:
+            mhost.import_module(dev_mod)
+    return mhost
+
+
+def build(
+    inputs: Union[PrimFunc, IRModule],
+    target: Optional[Union[str, Target]] = None,
+    name: str = "main",
+):
+    """
+    Build a function with a signature, generating code for devices
+    coupled with target information.
+
+    Parameters
+    ----------
+    inputs : Union[PrimFunc, IRModule]
+        The input to be built.
+    target : Optional[Union[str, Target]]
+        The target for compilation.
+    name : str
+        The name of the result function.
+
+    Returns
+    -------
+    tvm.runtime.Module
+        A module combining both host and device code.
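+
+    Examples
+    --------
+    A sketch, assuming ``matmul`` is a TVMScript ``T.prim_func`` and ``a``,
+    ``b``, ``c`` are ``tvm.nd.array`` buffers of matching shapes::
+
+        rt_mod = build(matmul, target="llvm")
+        rt_mod(a, b, c)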
+ """ + # Convert PrimFunc to IRModule + pass_ctx = tvm.ir.transform.PassContext.current() + if isinstance(inputs, PrimFunc): + f = inputs.with_attr("global_symbol", name) + if pass_ctx.config.get("tir.noalias", True): + f = f.with_attr("tir.noalias", True) + input_mod = tvm.ir.IRModule({tvm.ir.GlobalVar(name): f}) + elif isinstance(inputs, tvm.IRModule): + input_mod = inputs + else: + raise ValueError("Inputs must be IRModule or PrimFunc") + + # Get target and target_host + target = Target.current() if target is None else target + if target is None and isinstance(inputs, tvm.IRModule): + target_mod = {} + for gvar, func in inputs.functions.items(): + tgt = func.attrs.get("target", "llvm") + target_mod.setdefault(tgt, {})[gvar] = func + target_input_mod = { + tgt: tvm.IRModule(funcs).with_attrs(input_mod.attrs) + for tgt, funcs in target_mod.items() + } + else: + target_input_mod = {target: input_mod} + + annotated_mods = {} + for tgt, mod in target_input_mod.items(): + if not isinstance(tgt, (str, Target)): + raise ValueError("The key of inputs must be str or Target.") + if not isinstance(mod, tvm.IRModule): + raise ValueError("inputs must be IRModule, or dict of str to IRModule.") + annotated_mods[tgt] = mod + + annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods) + if not target_host: + for tar, mod in annotated_mods.items(): + if ndarray.device(tar.kind.name, 0).device_type == ndarray.cpu(0).device_type: + target_host = tar + break + if not target_host: + target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" + annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) + + input_mod = lower_module(input_mod, simple_mode=False) + return tir_to_runtime(annotated_mods, target_host) + + +tvm.register_func("tir.build", build) From 0489c62778f4336b0a8b99f921dfe6f859d5c7d9 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Wed, 19 Feb 2025 02:52:44 +0000 Subject: [PATCH 07/17] dev --- python/tvm/driver/build_module.py | 57 +++++++++++++------------------ 1 file changed, 24 insertions(+), 33 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index acc1393d167d..5f5db65b9d96 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -427,14 +427,14 @@ def codegen_build(mod: IRModule, target: Target) -> tvm.runtime.Module: return bf(mod, target) -def tir_to_runtime(inputs: Dict[Target, IRModule], target_host: Target): +def tir_to_runtime(inputs: Dict[Target, Tuple[IRModule, IRModule]], target_host: Target): """ Convert a collection of TIR IRModules (keyed by Target) into a single runtime Module. Parameters ---------- inputs : dict - Mapping from Target to IRModule. + Mapping from Target to Tuple[IRModule, IRModule]. target_host : Target The initial host target. @@ -443,38 +443,22 @@ def tir_to_runtime(inputs: Dict[Target, IRModule], target_host: Target): tvm.runtime.Module The final runtime module. 
""" - if not inputs: - raise ValueError("TIRToRuntime expects at least one IRModule as input.") - check_and_update_host_consistency(inputs, target_host) - if not target_host: - for tgt in inputs: - if tgt.get_target_device_type() == Device.kDLCPU: - target_host = tgt - break - if not target_host: - target_host = default_target_host(target_host) - check_and_update_host_consistency(inputs, target_host) - - first_module = next(iter(inputs.values())) + # Get the first module to get the attributes + # necessary for tests/python/codegen/test_target_codegen_blob.py::test_cuda_multi_lib + first_module = next(iter(inputs.values()))[0] mhost_all = ir.IRModule({}, attrs=first_module.attrs) - if mhost_all is None: - raise ValueError("The host module must be defined") device_modules = [] - for tgt, ir_module in inputs.items(): - if ir_module: - host_mod, device_mod = split_mixed_module(ir_module, tgt, target_host) - overrides_host_target = ( - tgt.get_target_device_type() == target_host.get_target_device_type() - ) - non_host_target_kind = tgt.kind != target_host.kind - if overrides_host_target and non_host_target_kind: - device_modules.append(codegen_build(host_mod, tgt)) - else: - mhost_all.update(host_mod) - if len(device_mod.functions) != 0: - device_modules.append(codegen_build(device_mod, tgt)) + for tgt, (host_mod, device_mod) in inputs.items(): + overrides_host_target = tgt.get_target_device_type() == target_host.get_target_device_type() + non_host_target_kind = tgt.kind != target_host.kind + if overrides_host_target and non_host_target_kind: + device_modules.append(codegen_build(host_mod, tgt)) + else: + mhost_all.update(host_mod) + if len(device_mod.functions) != 0: + device_modules.append(codegen_build(device_mod, tgt)) mhost = codegen_build(mhost_all, target_host) for dev_mod in device_modules: @@ -520,9 +504,9 @@ def build( # Get target and target_host target = Target.current() if target is None else target - if target is None and isinstance(inputs, tvm.IRModule): + if target is None and isinstance(input_mod, tvm.IRModule): target_mod = {} - for gvar, func in inputs.functions.items(): + for gvar, func in input_mod.functions.items(): tgt = func.attrs.get("target", "llvm") target_mod.setdefault(tgt, {})[gvar] = func target_input_mod = { @@ -550,9 +534,16 @@ def build( target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) + assert annotated_mods is not None and target_host is not None + check_and_update_host_consistency(annotated_mods, target_host) + + # Lower the module for tgt, mod in annotated_mods.items(): mod = lower_module(mod, simple_mode=False) - annotated_mods[tgt] = mod + host_mod, device_mod = split_mixed_module(mod, tgt, target_host) + annotated_mods[tgt] = (host_mod, device_mod) + + # Convert TIR IRModules to runtime Module by calling target.build return tir_to_runtime(annotated_mods, target_host) From b3b5349d88d725498b8fa9459a73904d61501686 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Wed, 19 Feb 2025 16:38:12 +0000 Subject: [PATCH 08/17] remove lower --- python/tvm/__init__.py | 2 +- python/tvm/driver/__init__.py | 2 +- python/tvm/driver/build_module.py | 522 +----------------- python/tvm/tir/__init__.py | 1 + python/tvm/tir/build.py | 59 +- .../codegen/test_target_codegen_llvm.py | 4 - .../test_hexagon/test_2d_physical_buffers.py | 12 - .../test_benchmark_elemwise_add.py | 6 - .../test_software_pipeline_async.py | 1 - tests/python/ir/test_pass_instrument.py | 2 +- 
tests/python/tir-base/test_lower_build.py | 133 ----- .../tir-base/test_tir_te_extern_primfunc.py | 1 - .../test_tir_transform_convert_ssa.py | 35 -- .../test_tir_transform_extract_constants.py | 2 - .../test_tir_transform_flatten_buffer.py | 31 -- .../test_tir_transform_narrow_datatype.py | 31 +- .../test_tir_transform_storage_rewrite.py | 82 --- 17 files changed, 33 insertions(+), 893 deletions(-) delete mode 100644 tests/python/tir-base/test_lower_build.py diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index abbab3ad6d39..f4519f834d74 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -55,7 +55,7 @@ from . import te # tvm.driver -from .driver import build, lower +from .driver import build # others from . import arith diff --git a/python/tvm/driver/__init__.py b/python/tvm/driver/__init__.py index 75e94cc91c83..b97375c3a364 100644 --- a/python/tvm/driver/__init__.py +++ b/python/tvm/driver/__init__.py @@ -15,4 +15,4 @@ # specific language governing permissions and limitations # under the License. """Namespace for driver APIs""" -from .build_module import lower, build +from .build_module import build diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 5f5db65b9d96..6001699f4ad1 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -17,454 +17,11 @@ # pylint: disable=invalid-name """The build utils in python.""" -from typing import Union, Optional, Dict, Tuple -import enum +from typing import Union, Optional import tvm -from tvm import tir, ir -from tvm.runtime import ndarray from tvm.tir import PrimFunc from tvm.ir.module import IRModule from tvm.target import Target -from tvm._ffi.runtime_ctypes import Device - - -def create_pass_list(disable_loop_partition: bool): - """Create a list of passes based on pass context configurations. - - Parameters - ---------- - disable_loop_partition : bool - Whether to disable loop partition pass. - - Returns - ------- - List[tvm.tir.transform.Pass] - List of passes to run. - """ - pass_ctx = tvm.transform.PassContext.current() - config = pass_ctx.config - # Retrieve configuration flags. - disable_vectorize = bool(config.get("tir.disable_vectorize", False)) - disable_storage_rewrite = bool(config.get("tir.disable_storage_rewrite", False)) - instrument_bound_checkers = bool(config.get("tir.instrument_bound_checkers", False)) - disable_cse_tir = bool(config.get("tir.disable_cse_tir", False)) - enable_equiv_terms_in_cse_tir = bool(config.get("tir.enable_equiv_terms_in_cse_tir", False)) - ptx_ldg32 = bool(config.get("tir.ptx_ldg32", False)) - instrument_lwp = bool(config.get("tir.instrument_lwp", False)) - add_lower_pass = config.get("tir.add_lower_pass", []) - - # Group user passes by phase (phases 0, 1, 2, and 3 where phase>=3 goes to 3) - user_passes = {0: [], 1: [], 2: [], 3: []} - for phase, p in add_lower_pass: - if not isinstance(phase, int) or phase < 0: - raise ValueError( - f"Phase number must be a non-negative integer, got {phase} of type {type(phase)}" - ) - user_passes[phase if phase < 3 else 3].append(p) - - # Construct phase-specific passes. 
- phase0 = user_passes[0] - - phase1 = [ - tir.transform.InjectPrefetch(), - tir.transform.TextureFlatten(), - tir.transform.StorageFlatten(64, instrument_bound_checkers), - tir.transform.LowerCrossThreadReduction(), - tir.transform.LowerInitBlock(), - tir.transform.PlanAndUpdateBufferAllocationLocation(), - tir.transform.ConvertBlocksToOpaque(), - tir.transform.LiftThreadBinding(), - tir.transform.ManifestSharedMemoryLocalStage(), - tir.transform.CompactBufferAllocation(), - tir.transform.LowerAutoCopy(), - tir.transform.UnifyThreadBinding(), - tir.transform.LowerMatchBuffer(), - tir.transform.Simplify(), - tir.transform.InjectPermutedLayout(), - tir.transform.Simplify(), - tir.transform.InjectSoftwarePipeline(), - tir.transform.TransformMmaBufferLayout(), - tir.transform.LowerOpaqueBlock(), - tir.transform.FlattenBuffer(), - tir.transform.BF16ComputeLegalize(), - tir.transform.NarrowDataType(32), - tir.transform.Simplify(), - ] + user_passes[1] - - phase2 = [] - if not disable_loop_partition: - phase2.append(tir.transform.LoopPartition()) - phase2.extend( - [ - tir.transform.VectorizeLoop(not disable_vectorize), - tir.transform.InjectVirtualThread(), - tir.transform.InjectDoubleBuffer(), - ] - ) - if not disable_storage_rewrite: - phase2.append(tir.transform.StorageRewrite()) - if config.get("tir.use_async_copy", False): - phase2.append(tir.transform.LowerAsyncDMA()) - phase2.extend( - [ - tir.transform.HoistIfThenElse(), - tir.transform.UnrollLoop(), - ] - ) - phase2 += user_passes[2] - - phase3 = [ - tir.transform.RenormalizeSplitPattern(), - tir.transform.Simplify(), - tir.transform.RemoveNoOp(), - tir.transform.RewriteUnsafeSelect(), - ] + user_passes[3] - - # Additional passes based on configuration. - extras = [] - if instrument_bound_checkers: - extras.append(tir.transform.InstrumentBoundCheckers()) - if ptx_ldg32: - extras.append(tir.transform.InjectPTXLDG32(True)) - extras.append( - tir.transform.CommonSubexprElimTIR(not disable_cse_tir, enable_equiv_terms_in_cse_tir) - ) - if instrument_lwp: - extras.append(tir.transform.InstrumentProfileIntrinsics()) - - return phase0 + phase1 + phase2 + phase3 + extras - - -def lower_module(inp: IRModule, simple_mode: bool = False) -> IRModule: - """Lowering step before building the target. - - Parameters - ---------- - inp : IRModule - The IRModule to be lowered. - simple_mode : bool - Whether to output only a simple, compact statement. - - Returns - ------- - IRModule - The lowered IRModule. - """ - return tvm.ir.transform.Sequential(create_pass_list(simple_mode))(inp) - - -def lower_primfunc(inp: PrimFunc, name: str = "main", simple_mode: bool = False) -> IRModule: - """Lowering step before building the target for a PrimFunc. - - Parameters - ---------- - inp : PrimFunc - The PrimFunc to be lowered. - name : str - The name of the resulting function. - simple_mode : bool - Whether to output only a simple, compact statement. - - Returns - ------- - IRModule - The lowered IRModule. - """ - pass_ctx = tvm.ir.transform.PassContext.current() - f = inp.with_attr("global_symbol", name) - if pass_ctx.config.get("tir.noalias", True): - f = f.with_attr("tir.noalias", True) - mod = tvm.ir.IRModule({tvm.ir.GlobalVar(name): f}) - return tvm.ir.transform.Sequential(create_pass_list(simple_mode))(mod) - - -def lower( - inp: Union[PrimFunc, IRModule], name: str = "main", simple_mode: bool = False -) -> IRModule: - """Lowering step before building the target. 
- - Parameters - ---------- - inp : Union[PrimFunc, IRModule] - The PrimFunc or IRModule to be lowered. - name : str - The name of the resulting function (if applicable). - simple_mode : bool - Whether to output only a simple, compact statement. - - Returns - ------- - IRModule - The lowered IRModule. - """ - if isinstance(inp, IRModule): - return lower_module(inp, simple_mode) - if isinstance(inp, PrimFunc): - return lower_primfunc(inp, name, simple_mode) - raise ValueError(f"Expected input to be IRModule or PrimFunc, but got {type(inp)}") - - -def check_and_update_host_consistency(targets: dict, host): - """ - Check and update the host field of the given legacy heterogeneous targets - for legacy target API compatibility. - - Parameters - ---------- - targets : dict - Dictionary mapping Target objects to IRModule objects. - host : Target - The target host to be updated. - """ - for tgt in list(targets): - if getattr(tgt, "host", None) is None: - tgt.host = host - - -def mixed_module_pass_manager(target: Target) -> tvm.ir.transform.Sequential: - """ - Constructs a Sequential transformation pass pipeline for a mixed module. - - Parameters - ---------- - target : Target - The target device for which the module is intended. - - Returns - ------- - tvm.ir.transform.Sequential - A sequential pass pipeline for the mixed module. - """ - pass_ctx = tvm.ir.transform.PassContext.current() - mixed_pass_list = [ - # Bind the target first so that target-specific attributes are available. - tir.transform.BindTarget(target), - tir.transform.FP8ComputeLegalize(), - # VerifyVTCMLimit must occur before LowerVtcmAlloc. - tir.transform.VerifyVTCMLimit(target), - tir.transform.LowerVtcmAlloc(), - tir.transform.VerifyMemory(), - tir.transform.AnnotateEntryFunc(), - ] - if pass_ctx.config.get("tir.detect_global_barrier", False): - mixed_pass_list.append(tir.transform.ThreadSync("global")) - mixed_pass_list.extend( - [ - tir.transform.ThreadSync("shared"), - tir.transform.ThreadSync("shared.dyn"), - tir.transform.ThreadSync("warp"), - tir.transform.InferFragment(), - tir.transform.LowerThreadAllreduce(), - ] - ) - if pass_ctx.config.get("tir.use_async_copy", False): - mixed_pass_list.append(tir.transform.InjectPTXAsyncCopy()) - if pass_ctx.config.get("tir.ptx_ldg32", False): - mixed_pass_list.append(tir.transform.InjectPTXLDG32()) - mixed_pass_list.extend( - [ - tir.transform.AnnotateDeviceRegions(), - tir.transform.SplitHostDevice(), - # MergeSharedMemoryAllocations must follow SplitHostDevice. - tir.transform.MergeSharedMemoryAllocations(), - tir.transform.MakePackedAPI(), - tir.transform.FP8StorageLegalize(), - tir.transform.BF16StorageLegalize(), - tir.transform.LowerDeviceKernelLaunch(), - ] - ) - return tvm.ir.transform.Sequential(mixed_pass_list) - - -class CallConv(enum.IntEnum): - """ - Enum representing different calling conventions. - Corresponds to the C++ tvm::ir::CallingConv enum. - """ - - kDefault = 0 - kCPackedFunc = 1 - kDeviceKernelLaunch = 2 - - -def host_module_pass_manager(target_host: Target) -> tvm.ir.transform.Sequential: - """ - Build a sequential pass pipeline for lowering the host part of a mixed module. - - Parameters - ---------- - target_host : Target - The host target for which to lower the module. - - Returns - ------- - tvm.ir.transform.Sequential - A sequential pass pipeline for host-specific transformations. - """ - host_pass_list = [ - # Filter out device kernel launches. 
- tir.transform.Filter( - lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault)) - != int(CallConv.kDeviceKernelLaunch) - ), - tir.transform.BindTarget(target_host), - tir.transform.LowerTVMBuiltin(), - tir.transform.LowerCustomDatatypes(), - tir.transform.LowerIntrin(), - tir.transform.LowerDeviceStorageAccessInfo(), - tir.transform.CombineContextCall(), - ] - return tvm.ir.transform.Sequential(host_pass_list) - - -def device_module_pass_manager(target: Target) -> tvm.ir.transform.Sequential: - """ - Build a sequential pass pipeline for lowering the device part of a mixed module. - - Parameters - ---------- - target : Target - The target for device-specific transformations. - - Returns - ------- - tvm.ir.transform.Sequential - A sequential pass pipeline for device-specific transformations. - """ - device_pass_list = [ - # Select only device kernel launches. - tir.transform.Filter( - lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault)) - == int(CallConv.kDeviceKernelLaunch) - ), - tir.transform.BindTarget(target), - tir.transform.LowerWarpMemory(), - tir.transform.Simplify(), - tir.transform.LowerCustomDatatypes(), - tir.transform.LowerDeviceStorageAccessInfo(), - tir.transform.LowerIntrin(), - ] - return tvm.ir.transform.Sequential(device_pass_list) - - -def split_mixed_module( - mod_mixed: IRModule, target_arg: Target, target_host_arg: Target -) -> Tuple[IRModule, IRModule]: - """ - Split a mixed module containing both device and host parts into separate modules, - applying appropriate transformations on each. - - Parameters - ---------- - mod_mixed : IRModule - The input module containing both device and host code. - target_arg : Target - The target for device-specific transformations. - target_host_arg : Target - The host target for lowering. - - Returns - ------- - Tuple[IRModule, IRModule] - (host module, device module) - """ - target, target_host = target_arg, target_host_arg - if getattr(target, "host", None) is None: - target.host = target_host - if mod_mixed is None: - raise ValueError("Module must be defined") - - mod_mixed = mixed_module_pass_manager(target)(mod_mixed) - host_mod = host_module_pass_manager(target_host)(mod_mixed) - device_mod = device_module_pass_manager(target)(mod_mixed) - - # Warn if target is GPU but no device code was generated. - if "gpu" in target.keys and len(device_mod.functions) == 0: - print( - f"Warning: Specified target {target} but cannot find device code. " - "Did you forget to bind?" - ) - - return host_mod, device_mod - - -def default_target_host(target: Target) -> Target: - """ - Determine the default target host for a given target. - """ - if target is not None and target.device_type == Device.kDLCPU: - return target - # In practice, llvm_enabled should be determined dynamically. - llvm_enabled = True - return Target("llvm") if llvm_enabled else Target("stackvm") - - -def codegen_build(mod: IRModule, target: Target) -> tvm.runtime.Module: - """ - Build a runtime module from an IRModule and a Target. - - If the "tir.disable_assert" flag is set in the pass context, - the SkipAssert transformation is applied. - - Parameters - ---------- - mod : IRModule - The input IRModule. - target : Target - The target for which to build the module. - - Returns - ------- - tvm.runtime.Module - The built runtime module. - """ - if tvm.ir.transform.PassContext.current().config.get("tir.disable_assert", False): - mod = tvm.tir.transform.SkipAssert()(mod) - build_f_name = "target.build." 
+ target.kind.name - bf = tvm.get_global_func(build_f_name) - if bf is None: - raise ValueError(f"{build_f_name} is not enabled") - return bf(mod, target) - - -def tir_to_runtime(inputs: Dict[Target, Tuple[IRModule, IRModule]], target_host: Target): - """ - Convert a collection of TIR IRModules (keyed by Target) into a single runtime Module. - - Parameters - ---------- - inputs : dict - Mapping from Target to Tuple[IRModule, IRModule]. - target_host : Target - The initial host target. - - Returns - ------- - tvm.runtime.Module - The final runtime module. - """ - - # Get the first module to get the attributes - # necessary for tests/python/codegen/test_target_codegen_blob.py::test_cuda_multi_lib - first_module = next(iter(inputs.values()))[0] - mhost_all = ir.IRModule({}, attrs=first_module.attrs) - - device_modules = [] - for tgt, (host_mod, device_mod) in inputs.items(): - overrides_host_target = tgt.get_target_device_type() == target_host.get_target_device_type() - non_host_target_kind = tgt.kind != target_host.kind - if overrides_host_target and non_host_target_kind: - device_modules.append(codegen_build(host_mod, tgt)) - else: - mhost_all.update(host_mod) - if len(device_mod.functions) != 0: - device_modules.append(codegen_build(device_mod, tgt)) - - mhost = codegen_build(mhost_all, target_host) - for dev_mod in device_modules: - if dev_mod is not None: - mhost.import_module(dev_mod) - return mhost def build( @@ -472,79 +29,4 @@ def build( target: Optional[Union[str, Target]] = None, name: str = "main", ): - """ - Build a function with a signature, generating code for devices - coupled with target information. - - Parameters - ---------- - inputs : Union[PrimFunc, IRModule] - The input to be built. - target : Optional[Union[str, Target]] - The target for compilation. - name : str - The name of the result function. - - Returns - ------- - tvm.runtime.Module - A module combining both host and device code. 
- """ - # Convert PrimFunc to IRModule - pass_ctx = tvm.ir.transform.PassContext.current() - if isinstance(inputs, PrimFunc): - f = inputs.with_attr("global_symbol", name) - if pass_ctx.config.get("tir.noalias", True): - f = f.with_attr("tir.noalias", True) - input_mod = tvm.ir.IRModule({tvm.ir.GlobalVar(name): f}) - elif isinstance(inputs, tvm.IRModule): - input_mod = inputs - else: - raise ValueError("Inputs must be IRModule or PrimFunc") - - # Get target and target_host - target = Target.current() if target is None else target - if target is None and isinstance(input_mod, tvm.IRModule): - target_mod = {} - for gvar, func in input_mod.functions.items(): - tgt = func.attrs.get("target", "llvm") - target_mod.setdefault(tgt, {})[gvar] = func - target_input_mod = { - tgt: tvm.IRModule(funcs).with_attrs(input_mod.attrs) - for tgt, funcs in target_mod.items() - } - else: - target_input_mod = {target: input_mod} - - annotated_mods = {} - for tgt, mod in target_input_mod.items(): - if not isinstance(tgt, (str, Target)): - raise ValueError("The key of inputs must be str or Target.") - if not isinstance(mod, tvm.IRModule): - raise ValueError("inputs must be IRModule, or dict of str to IRModule.") - annotated_mods[tgt] = mod - - annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods) - if not target_host: - for tar, mod in annotated_mods.items(): - if ndarray.device(tar.kind.name, 0).device_type == ndarray.cpu(0).device_type: - target_host = tar - break - if not target_host: - target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" - annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) - - assert annotated_mods is not None and target_host is not None - check_and_update_host_consistency(annotated_mods, target_host) - - # Lower the module - for tgt, mod in annotated_mods.items(): - mod = lower_module(mod, simple_mode=False) - host_mod, device_mod = split_mixed_module(mod, tgt, target_host) - annotated_mods[tgt] = (host_mod, device_mod) - - # Convert TIR IRModules to runtime Module by calling target.build - return tir_to_runtime(annotated_mods, target_host) - - -tvm.register_func("tir.build", build) + return tvm.tir.build(inputs, target, name) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 1d7352f66527..fc1c76ad6f7e 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -108,3 +108,4 @@ from . import transform from . import analysis from . import stmt_functor +from .build import build \ No newline at end of file diff --git a/python/tvm/tir/build.py b/python/tvm/tir/build.py index b62d799390a5..5f5db65b9d96 100644 --- a/python/tvm/tir/build.py +++ b/python/tvm/tir/build.py @@ -427,14 +427,14 @@ def codegen_build(mod: IRModule, target: Target) -> tvm.runtime.Module: return bf(mod, target) -def tir_to_runtime(inputs: Dict[Target, IRModule], target_host: Target): +def tir_to_runtime(inputs: Dict[Target, Tuple[IRModule, IRModule]], target_host: Target): """ Convert a collection of TIR IRModules (keyed by Target) into a single runtime Module. Parameters ---------- inputs : dict - Mapping from Target to IRModule. + Mapping from Target to Tuple[IRModule, IRModule]. target_host : Target The initial host target. @@ -443,38 +443,22 @@ def tir_to_runtime(inputs: Dict[Target, IRModule], target_host: Target): tvm.runtime.Module The final runtime module. 
""" - if not inputs: - raise ValueError("TIRToRuntime expects at least one IRModule as input.") - check_and_update_host_consistency(inputs, target_host) - if not target_host: - for tgt in inputs: - if tgt.get_target_device_type() == Device.kDLCPU: - target_host = tgt - break - if not target_host: - target_host = default_target_host(target_host) - check_and_update_host_consistency(inputs, target_host) - - first_module = next(iter(inputs.values())) + # Get the first module to get the attributes + # necessary for tests/python/codegen/test_target_codegen_blob.py::test_cuda_multi_lib + first_module = next(iter(inputs.values()))[0] mhost_all = ir.IRModule({}, attrs=first_module.attrs) - if mhost_all is None: - raise ValueError("The host module must be defined") device_modules = [] - for tgt, ir_module in inputs.items(): - if ir_module: - host_mod, device_mod = split_mixed_module(ir_module, tgt, target_host) - overrides_host_target = ( - tgt.get_target_device_type() == target_host.get_target_device_type() - ) - non_host_target_kind = tgt.kind != target_host.kind - if overrides_host_target and non_host_target_kind: - device_modules.append(codegen_build(host_mod, tgt)) - else: - mhost_all.update(host_mod) - if len(device_mod.functions) != 0: - device_modules.append(codegen_build(device_mod, tgt)) + for tgt, (host_mod, device_mod) in inputs.items(): + overrides_host_target = tgt.get_target_device_type() == target_host.get_target_device_type() + non_host_target_kind = tgt.kind != target_host.kind + if overrides_host_target and non_host_target_kind: + device_modules.append(codegen_build(host_mod, tgt)) + else: + mhost_all.update(host_mod) + if len(device_mod.functions) != 0: + device_modules.append(codegen_build(device_mod, tgt)) mhost = codegen_build(mhost_all, target_host) for dev_mod in device_modules: @@ -520,9 +504,9 @@ def build( # Get target and target_host target = Target.current() if target is None else target - if target is None and isinstance(inputs, tvm.IRModule): + if target is None and isinstance(input_mod, tvm.IRModule): target_mod = {} - for gvar, func in inputs.functions.items(): + for gvar, func in input_mod.functions.items(): tgt = func.attrs.get("target", "llvm") target_mod.setdefault(tgt, {})[gvar] = func target_input_mod = { @@ -550,7 +534,16 @@ def build( target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) - input_mod = lower_module(input_mod, simple_mode=False) + assert annotated_mods is not None and target_host is not None + check_and_update_host_consistency(annotated_mods, target_host) + + # Lower the module + for tgt, mod in annotated_mods.items(): + mod = lower_module(mod, simple_mode=False) + host_mod, device_mod = split_mixed_module(mod, tgt, target_host) + annotated_mods[tgt] = (host_mod, device_mod) + + # Convert TIR IRModules to runtime Module by calling target.build return tir_to_runtime(annotated_mods, target_host) diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index e3ccff49ba1b..7910cd372ffc 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -557,9 +557,6 @@ def _show_info(): print("dtype: {}".format(dtype)) print("dividend range: [{}, {}]".format(start, end)) print("divisor range: [{}, {}]".format(dstart, dend)) - lowered = tvm.lower(sch.mod, simple_mode=True) - print("Lowered code:") - print(lowered) # Check that the computed 
values are correct for i in range(start, end + 1): @@ -793,7 +790,6 @@ def _transform(f, *_): return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize") with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, my_vectorize())]}): - ir = tvm.lower(sch.mod, simple_mode=True) module = tvm.build(sch.mod) a_ = tvm.nd.array(np.arange(1, 9, dtype="int32")) b_ = tvm.nd.array(np.arange(8, 0, -1, dtype="int32")) diff --git a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py index 99fc6ac074c2..169d868b5479 100644 --- a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py +++ b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py @@ -244,15 +244,6 @@ def apply_transform(block, buffer_name, layout): return [sch.mod] - @tvm.testing.fixture - def ir_module(self, schedule_args): - # If the two buffers are accessed with the same indices, CSE - # will replace them with a Let binding. Since this makes it - # harder to test what the transformed indices are, disabling - # the CSE pass for this test. - with tvm.transform.PassContext(disabled_pass=["tir.CommonSubexprElimTIR"]): - return tvm.lower(*schedule_args) - @tvm.testing.fixture def uses_unsupported_physical_dimensions( # pylint: disable=invalid-name self, target_host, input_layout, working_layout, output_layout @@ -291,9 +282,6 @@ def test_cache_shape(self, ir_module, input_layout, working_layout, output_layou assert len(buffer.shape) == expected_physical_dimensions - def test_lower(self, schedule_args): - assert tvm.lower(*schedule_args) - @requires_hexagon_toolchain def test_build(self, schedule_args, target_host, input_layout, working_layout, output_layout): """Testing build success/failure diff --git a/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py b/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py index a927532c8f4a..f0cefa3fe256 100644 --- a/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py +++ b/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py @@ -199,12 +199,6 @@ def _benchmark_hexagon_elementwise_add_kernel( try: ns_tir_module = _get_irmod_elemwise_add(shape, dtype, mem_scope) - # Dump the primfunc NS-TIR (as text) to the log file... - lowered_mod = tvm.lower(ns_tir_module, _PRIMFUNC_NAME) - log_file.write("LOWERED IR MODULE:\n") - log_file.write(str(lowered_mod)) - log_file.write("\n") - # Lower the primfunc's IRModule to Hexagon object code... 
input1 = tvm.te.placeholder(shape, dtype=dtype) input2 = tvm.te.placeholder(shape, dtype=dtype) diff --git a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py index 498e29e407b4..d45b35befd11 100644 --- a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py +++ b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py @@ -183,7 +183,6 @@ def test_async_software_pipeline( "tir.experimental_dma_bypass_cache": 1, } ): - # tvm.lower(schedule.mod["main"]).show() func = tvm.build(schedule.mod["main"], target=get_hexagon_target("v68")) with hexagon_launcher.create_session() as hexagon_session: diff --git a/tests/python/ir/test_pass_instrument.py b/tests/python/ir/test_pass_instrument.py index cfeb70b96388..718cf3a663e5 100644 --- a/tests/python/ir/test_pass_instrument.py +++ b/tests/python/ir/test_pass_instrument.py @@ -38,7 +38,7 @@ def func(a: T.handle, b: T.handle) -> None: B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0 with tvm.transform.PassContext(opt_level=3, instruments=[PrintBeforeAll(), PrintAfterAll()]): - tvm.lower(func) + tvm.build(func) all_passes_output = capsys.readouterr().out assert "Before Running Pass:" in all_passes_output assert "After Running Pass:" in all_passes_output diff --git a/tests/python/tir-base/test_lower_build.py b/tests/python/tir-base/test_lower_build.py deleted file mode 100644 index edb3ed351e5d..000000000000 --- a/tests/python/tir-base/test_lower_build.py +++ /dev/null @@ -1,133 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import numpy as np - -import tvm -from tvm.ir.module import IRModule -from tvm.script import tir as T -import tvm.testing - - -def _check_module_with_numpy(mod, shape=(128, 128, 128)): - m, n, k = shape - a = tvm.nd.array(np.random.rand(m, k).astype("float32")) - b = tvm.nd.array(np.random.rand(n, k).astype("float32")) - c = tvm.nd.array(np.zeros((m, n), dtype="float32")) - c_np = np.dot(a.numpy(), b.numpy().transpose()) - mod(a, b, c) - tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5) - - -# pylint: disable=no-self-argument, missing-class-docstring, missing-function-docstring -@T.prim_func -def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, [128, 128]) - B = T.match_buffer(b, [128, 128]) - C = T.match_buffer(c, [128, 128]) - for i, j in T.grid(128, 128): - with T.block("init"): - vi, vj = T.axis.remap("SS", [i, j]) - C[vi, vj] = T.float32(0) - for k in range(128): - with T.block("update"): - vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] - - -@tvm.script.ir_module -class LoweredModule: - @T.prim_func - def main( - A: T.Buffer((128, 128), "float32"), - B: T.Buffer((128, 128), "float32"), - C: T.Buffer((128, 128), "float32"), - ) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "from_legacy_te_schedule": True, "tir.noalias": True}) - A_flat = T.Buffer([16384], data=A.data) - B_flat = T.Buffer([16384], data=B.data) - C_flat = T.Buffer([16384], data=C.data) - # body - for x, y in T.grid(128, 128): - C_flat[x * 128 + y] = 0.0 - for k in T.serial(0, 128): - C_flat[x * 128 + y] = ( - C_flat[x * 128 + y] + A_flat[x * 128 + k] * B_flat[y * 128 + k] - ) - - -@tvm.script.ir_module -class LoweredTIRModule: - @T.prim_func - def main( - A: T.Buffer((128, 128), "float32"), - B: T.Buffer((128, 128), "float32"), - C: T.Buffer((128, 128), "float32"), - ) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A_flat = T.Buffer([16384], data=A.data) - B_flat = T.Buffer([16384], data=B.data) - C_flat = T.Buffer([16384], data=C.data) - # body - for x, y in T.grid(128, 128): - C_flat[x * 128 + y] = 0.0 - for k in T.serial(0, 128): - C_flat[x * 128 + y] = ( - C_flat[x * 128 + y] + A_flat[x * 128 + k] * B_flat[y * 128 + k] - ) - - -def test_lower_build_tir_func(): - # check lowering with the CSE pass disabled as otherwise it would do some commoning - with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]): - ir_mod = tvm.lower(matmul) - tvm.ir.assert_structural_equal(ir_mod, LoweredTIRModule) - # check building - mod = tvm.build(matmul, target="llvm") - _check_module_with_numpy(mod) - - -def test_lower_build_tir_module(): - func = matmul.with_attr("global_symbol", "main") - func = func.with_attr("tir.noalias", T.bool(True)) - ir_mod = IRModule({"main": func}) - # check lowering with the CSE pass disabled as otherwise it would do some commoning - with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]): - lowered_mod = tvm.lower(ir_mod) - tvm.ir.assert_structural_equal(lowered_mod, LoweredTIRModule) - # check building - mod = tvm.build(ir_mod, target="llvm") - _check_module_with_numpy(mod) - - -def test_lower_build_lowered_module(): - # check lowering with the CSE pass disabled as otherwise it would do some commoning - with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]): - ir_mod = tvm.lower(LoweredTIRModule) - tvm.ir.assert_structural_equal(ir_mod, LoweredTIRModule) - 
# check building - mod = tvm.build(ir_mod, target="llvm") - _check_module_with_numpy(mod) - - -if __name__ == "__main__": - test_lower_build_te_schedule() - test_lower_build_tir_func() - test_lower_build_tir_module() - test_lower_build_lowered_module() diff --git a/tests/python/tir-base/test_tir_te_extern_primfunc.py b/tests/python/tir-base/test_tir_te_extern_primfunc.py index 45ca7a1c7256..16bc0b0ae2fc 100644 --- a/tests/python/tir-base/test_tir_te_extern_primfunc.py +++ b/tests/python/tir-base/test_tir_te_extern_primfunc.py @@ -192,7 +192,6 @@ def test_te_extern_call(self, func, params, verify): input_tensors = [te.placeholder(buf_name_map[name].shape) for name in params] output = te.extern_primfunc(input_tensors, prim_func) rt_prim_func = te.create_prim_func(tensors_from_extern_op(output, prim_func)) - tvm.ir.assert_structural_equal(tvm.lower(prim_func), tvm.lower(rt_prim_func)) target = tvm.target.Target("llvm") func = tvm.build(rt_prim_func, target=target) diff --git a/tests/python/tir-transform/test_tir_transform_convert_ssa.py b/tests/python/tir-transform/test_tir_transform_convert_ssa.py index ec768ba74f7b..b93747c84a09 100644 --- a/tests/python/tir-transform/test_tir_transform_convert_ssa.py +++ b/tests/python/tir-transform/test_tir_transform_convert_ssa.py @@ -234,41 +234,6 @@ def func(A: T.Buffer(1, "float32")): assert before.same_as(after) -class TestDedupAutoBroadcastBuffer(BaseBeforeAfter): - """De-dup auto-broadcast buffers - - Auto-broadcast buffers can define additional variables during the - `Buffer::Buffer` constructor for the strides. This is intended to - be used for match buffers, where these variables are defined based - on the argument being passed in. - - These additional variables can cause errors when copying a buffer - with the `Buffer::Buffer` constructor. If a buffer has non-empty - shape, empty strides, and kAutoBroadcast type, then the resulting - buffer will have additional strides defined. Such a buffer can - result from lowering of a scalar buffer, which will be flattened - to a shape of [1]. - - Previous implementations of ConvertSSA incorrectly handled this - case, resulting in undefined stride variables. 
- """ - - def _make_func(self): - @T.prim_func - def func(a: T.handle): - A = T.match_buffer(a, shape=(), dtype="float32", buffer_type="auto") - A[()] = 1.0 - - return tvm.lower(func)["main"] - - def before(self): - func = self._make_func() - return tvm.IRModule({"func_a": func, "func_b": func}) - - def expected(self): - return tvm.IRModule({"func_a": self._make_func(), "func_b": self._make_func()}) - - class TestKeepDuplicateThreadIdxInSameFunction(BaseBeforeAfter): """Environment threads are treated as being at function scope diff --git a/tests/python/tir-transform/test_tir_transform_extract_constants.py b/tests/python/tir-transform/test_tir_transform_extract_constants.py index b3e0aa74f96d..cbfb6d39bcd2 100644 --- a/tests/python/tir-transform/test_tir_transform_extract_constants.py +++ b/tests/python/tir-transform/test_tir_transform_extract_constants.py @@ -63,8 +63,6 @@ def _visit(stmt): for n, f in mod.functions.items(): tvm.tir.stmt_functor.post_order_visit(f.body, _visit) - tvm.lower(mod) - if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_flatten_buffer.py b/tests/python/tir-transform/test_tir_transform_flatten_buffer.py index b215398622cc..925f004cc527 100644 --- a/tests/python/tir-transform/test_tir_transform_flatten_buffer.py +++ b/tests/python/tir-transform/test_tir_transform_flatten_buffer.py @@ -322,36 +322,5 @@ def expected(): T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5]) -def test_lower_2d_physical_memory(): - """Axis separators should preserve 2-d buffers through lowering. - - A catch-all test to ensure that defining axis_separators is - sufficient to maintain non-flat buffer descriptions through all - lowering steps. - """ - - # This test doesn't use CompareBeforeAfter, because the after step - # is not currently expressible in TVMScript. This test can be - # re-written after https://github.com/apache/tvm/pull/12412. - - @T.prim_func - def func(): - buf = T.alloc_buffer( - [1, 1], - dtype="int32", - scope="global", - axis_separators=[1], - ) - buf[0, 0] = 0 - - lowered = tvm.lower(func)["main"] - assert isinstance(lowered.body, tvm.tir.Allocate) - assert list(lowered.body.extents) == [1, 1], ( - "Non-flat buffer allocations, " - "marked by axis_separators, " - "flattened to flat memory allocation." 
- ) - - if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py index 93c680c846c5..e5cb3667633c 100644 --- a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py +++ b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py @@ -111,33 +111,6 @@ def check(m, n, target_bits, target_dtype): check(2**14, 32, target_bits=16, target_dtype="int32") -def test_thread_axis_2(): - # fmt: off - @tvm.script.ir_module - class Before: - @T.prim_func - def main(T_reshape: T.Buffer((1, 12, 384, 384), "float32"), placeholder_1: T.Buffer((T.int64(1), T.int64(12), T.int64(384), 384), "bool"), T_where: T.Buffer((T.int64(1), T.int64(12), T.int64(384), 384), "float32")) -> None: - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") - for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"): - for i0_i1_i2_i3_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): - for i0_i1_i2_i3_fused_0 in T.serial(T.int64(7)): - with T.block("T_where"): - ax0 = T.axis.spatial(T.int64(1), T.int64(0)) - ax1 = T.axis.spatial(T.int64(12), ((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2) % T.int64(1769472) // T.int64(147456)) - ax2 = T.axis.spatial(T.int64(384), ((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2) % T.int64(147456) // T.int64(384)) - ax3 = T.axis.spatial(384, T.cast(((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2) % T.int64(384), "int32")) - T.where((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2 < T.int64(1769472)) - T.reads(placeholder_1[ax0, ax1, ax2, ax3], T_reshape[ax0, ax1, ax2, ax3]) - T.writes(T_where[ax0, ax1, ax2, ax3]) - T_where[ax0, ax1, ax2, ax3] = T.Select(T.cast(placeholder_1[ax0, ax1, ax2, ax3], "int32") != 0, T.float32(-1000000000), T_reshape[ax0, ax1, ax2, ax3]) - # fmt: on - # TODO(@junrushao1994): make this test more "unit" after the new TVMScript printer/parser lands - tvm.lower(Before) - - def test_multilanes(): def check(m, lanes, target_bits, target_dtype): ib = tvm.tir.ir_builder.create() @@ -181,9 +154,7 @@ def check(m, n, target_bits, target_dtype): # The maximum index is (2**15 * 2**15 - 1) * 2 <= 2**31 - 1 check(const(2**15, "int64"), const(2**15, "int64"), target_bits=32, target_dtype="int32") # The maximum index is (2**15 * 2**15 - 1 + 2**15) * 2 > 2**31 - 1 - check( - const(2**15, "int64"), const((2**15 + 1), "int64"), target_bits=32, target_dtype="int64" - ) + check(const(2**15, "int64"), const((2**15 + 1), "int64"), target_bits=32, target_dtype="int64") def test_condition(): diff --git a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py index ab91c6c7b330..548b199a94ce 100644 --- a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py +++ b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py @@ -610,87 +610,5 @@ def expected(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")): D[i] = C[i] -def test_vulkan_smem_reuse(): - target = tvm.target.Target( - { - "keys": ["vulkan", "gpu"], - "kind": "vulkan", - "max_num_threads": 256, - "max_threads_per_block": 256, - "supports_float32": True, - "supports_int32": True, - "tag": "", - 
"thread_warp_size": 1, - } - ) - - @T.prim_func(private=True) - def func(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - A_shared = T.allocate([4], "float32", "shared") - A_local = T.allocate([4], "float32", "local") - B_shared = T.allocate([4], "float16", "shared") - A_shared_1 = T.Buffer((4,), data=A_shared, scope="shared") - with T.launch_thread("threadIdx.x", 4) as threadIdx_x: - A_1 = T.Buffer((4,), data=A.data) - A_shared_1[threadIdx_x] = A_1[threadIdx_x] - A_local_1 = T.Buffer((4,), data=A_local, scope="local") - with T.launch_thread("threadIdx.x", 4) as threadIdx_x: - A_local_1[threadIdx_x] = A_shared_1[threadIdx_x] - B_shared_1 = T.Buffer((4,), "float16", data=B_shared, scope="shared") - with T.launch_thread("threadIdx.x", 4) as threadIdx_x: - B_shared_1[threadIdx_x] = T.Cast("float16", A_local_1[threadIdx_x]) - threadIdx_x = T.launch_thread("threadIdx.x", 4) - B_1 = T.Buffer((4,), "float16", data=B.data) - B_1[threadIdx_x] = B_shared_1[threadIdx_x] - - @T.prim_func(private=True) - def normal_lowering(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")): - T.func_attr({"tir.noalias": T.bool(True)}) - A_shared = T.allocate([4], "float32", "shared") - A_local = T.allocate([4], "float32", "local") - A_shared_1 = T.Buffer((4,), data=A_shared, scope="shared") - with T.launch_thread("threadIdx.x", 4) as threadIdx_x: - A_1 = T.Buffer((4,), data=A.data) - A_shared_1[threadIdx_x] = A_1[threadIdx_x] - A_local_1 = T.Buffer((4,), data=A_local, scope="local") - with T.launch_thread("threadIdx.x", 4) as threadIdx_x: - A_local_1[threadIdx_x] = A_shared_1[threadIdx_x] - A_shared_2 = T.Buffer((4,), "float16", data=A_shared, scope="shared") - with T.launch_thread("threadIdx.x", 4) as threadIdx_x: - A_shared_2[threadIdx_x] = T.Cast("float16", A_local_1[threadIdx_x]) - threadIdx_x = T.launch_thread("threadIdx.x", 4) - B_1 = T.Buffer((4,), "float16", data=B.data) - B_1[threadIdx_x] = A_shared_2[threadIdx_x] - - @T.prim_func(private=True) - def no_reuse_lowering(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")): - T.func_attr({"target": target, "tir.noalias": T.bool(True)}) - A_shared_1 = T.allocate([4], "float32", "shared") - A_local_1 = T.allocate([4], "float32", "local") - B_shared_1 = T.allocate([4], "float16", "shared") - A_shared_1_1 = T.Buffer((4,), data=A_shared_1, scope="shared") - with T.launch_thread("threadIdx.x", 4) as threadIdx_x: - A_1 = T.Buffer((4,), data=A.data) - A_shared_1_1[threadIdx_x] = A_1[threadIdx_x] - A_local_1_1 = T.Buffer((4,), data=A_local_1, scope="local") - with T.launch_thread("threadIdx.x", 4) as threadIdx_x: - A_local_1_1[threadIdx_x] = A_shared_1_1[threadIdx_x] - B_shared_1_1 = T.Buffer((4,), "float16", data=B_shared_1, scope="shared") - with T.launch_thread("threadIdx.x", 4) as threadIdx_x: - B_shared_1_1[threadIdx_x] = T.Cast("float16", A_local_1_1[threadIdx_x]) - threadIdx_x = T.launch_thread("threadIdx.x", 4) - B_1 = T.Buffer((4,), "float16", data=B.data) - B_1[threadIdx_x] = B_shared_1_1[threadIdx_x] - - # Reuse shared memory when lowering without target. - mod = tvm.IRModule({"main": func}) - tvm.ir.assert_structural_equal(tvm.lower(mod)["main"], normal_lowering) - - # No shared memory reuse when lowering with target Vulkan. 
- mod = tvm.tir.transform.BindTarget(target)(mod) - tvm.ir.assert_structural_equal(tvm.lower(mod)["main"], no_reuse_lowering) - - if __name__ == "__main__": tvm.testing.main() From c352797ff24ce3ce45c3e5a169206934e88d4f37 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Wed, 19 Feb 2025 18:01:22 +0000 Subject: [PATCH 09/17] remove lower --- docs/reference/api/python/driver.rst | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/reference/api/python/driver.rst b/docs/reference/api/python/driver.rst index 1f1bc8c7cf7b..97c30ec2d25b 100644 --- a/docs/reference/api/python/driver.rst +++ b/docs/reference/api/python/driver.rst @@ -19,6 +19,4 @@ tvm.driver ---------- .. automodule:: tvm.driver -.. autofunction:: tvm.lower - .. autofunction:: tvm.build From 068f36d406078c9f8558e26d624cdd7a579c7326 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Thu, 20 Feb 2025 02:40:14 +0000 Subject: [PATCH 10/17] reorganize --- python/tvm/driver/build_module.py | 6 +- python/tvm/tir/__init__.py | 3 +- python/tvm/tir/build.py | 496 ++---------------- python/tvm/tir/pipeline.py | 197 +++++++ python/tvm/tir/transform/transform.py | 2 +- src/ir/transform.cc | 19 - src/tir/ir/transform.cc | 20 + .../codegen/test_target_codegen_llvm.py | 37 -- 8 files changed, 280 insertions(+), 500 deletions(-) create mode 100644 python/tvm/tir/pipeline.py diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 6001699f4ad1..8d6a2a534389 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -25,8 +25,8 @@ def build( - inputs: Union[PrimFunc, IRModule], + mod: Union[PrimFunc, IRModule], target: Optional[Union[str, Target]] = None, - name: str = "main", + pipeline: Optional[Union[str, tvm.transform.Pass]] = "default_tir", ): - return tvm.tir.build(inputs, target, name) + return tvm.tir.build(mod, target, pipeline) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index fc1c76ad6f7e..9ff5bff5f1ff 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -108,4 +108,5 @@ from . import transform from . import analysis from . import stmt_functor -from .build import build \ No newline at end of file +from .build import build +from .pipeline import get_pipeline diff --git a/python/tvm/tir/build.py b/python/tvm/tir/build.py index 5f5db65b9d96..516eb5d1d114 100644 --- a/python/tvm/tir/build.py +++ b/python/tvm/tir/build.py @@ -28,377 +28,6 @@ from tvm._ffi.runtime_ctypes import Device -def create_pass_list(disable_loop_partition: bool): - """Create a list of passes based on pass context configurations. - - Parameters - ---------- - disable_loop_partition : bool - Whether to disable loop partition pass. - - Returns - ------- - List[tvm.tir.transform.Pass] - List of passes to run. - """ - pass_ctx = tvm.transform.PassContext.current() - config = pass_ctx.config - # Retrieve configuration flags. 
- disable_vectorize = bool(config.get("tir.disable_vectorize", False)) - disable_storage_rewrite = bool(config.get("tir.disable_storage_rewrite", False)) - instrument_bound_checkers = bool(config.get("tir.instrument_bound_checkers", False)) - disable_cse_tir = bool(config.get("tir.disable_cse_tir", False)) - enable_equiv_terms_in_cse_tir = bool(config.get("tir.enable_equiv_terms_in_cse_tir", False)) - ptx_ldg32 = bool(config.get("tir.ptx_ldg32", False)) - instrument_lwp = bool(config.get("tir.instrument_lwp", False)) - add_lower_pass = config.get("tir.add_lower_pass", []) - - # Group user passes by phase (phases 0, 1, 2, and 3 where phase>=3 goes to 3) - user_passes = {0: [], 1: [], 2: [], 3: []} - for phase, p in add_lower_pass: - if not isinstance(phase, int) or phase < 0: - raise ValueError( - f"Phase number must be a non-negative integer, got {phase} of type {type(phase)}" - ) - user_passes[phase if phase < 3 else 3].append(p) - - # Construct phase-specific passes. - phase0 = user_passes[0] - - phase1 = [ - tir.transform.InjectPrefetch(), - tir.transform.TextureFlatten(), - tir.transform.StorageFlatten(64, instrument_bound_checkers), - tir.transform.LowerCrossThreadReduction(), - tir.transform.LowerInitBlock(), - tir.transform.PlanAndUpdateBufferAllocationLocation(), - tir.transform.ConvertBlocksToOpaque(), - tir.transform.LiftThreadBinding(), - tir.transform.ManifestSharedMemoryLocalStage(), - tir.transform.CompactBufferAllocation(), - tir.transform.LowerAutoCopy(), - tir.transform.UnifyThreadBinding(), - tir.transform.LowerMatchBuffer(), - tir.transform.Simplify(), - tir.transform.InjectPermutedLayout(), - tir.transform.Simplify(), - tir.transform.InjectSoftwarePipeline(), - tir.transform.TransformMmaBufferLayout(), - tir.transform.LowerOpaqueBlock(), - tir.transform.FlattenBuffer(), - tir.transform.BF16ComputeLegalize(), - tir.transform.NarrowDataType(32), - tir.transform.Simplify(), - ] + user_passes[1] - - phase2 = [] - if not disable_loop_partition: - phase2.append(tir.transform.LoopPartition()) - phase2.extend( - [ - tir.transform.VectorizeLoop(not disable_vectorize), - tir.transform.InjectVirtualThread(), - tir.transform.InjectDoubleBuffer(), - ] - ) - if not disable_storage_rewrite: - phase2.append(tir.transform.StorageRewrite()) - if config.get("tir.use_async_copy", False): - phase2.append(tir.transform.LowerAsyncDMA()) - phase2.extend( - [ - tir.transform.HoistIfThenElse(), - tir.transform.UnrollLoop(), - ] - ) - phase2 += user_passes[2] - - phase3 = [ - tir.transform.RenormalizeSplitPattern(), - tir.transform.Simplify(), - tir.transform.RemoveNoOp(), - tir.transform.RewriteUnsafeSelect(), - ] + user_passes[3] - - # Additional passes based on configuration. - extras = [] - if instrument_bound_checkers: - extras.append(tir.transform.InstrumentBoundCheckers()) - if ptx_ldg32: - extras.append(tir.transform.InjectPTXLDG32(True)) - extras.append( - tir.transform.CommonSubexprElimTIR(not disable_cse_tir, enable_equiv_terms_in_cse_tir) - ) - if instrument_lwp: - extras.append(tir.transform.InstrumentProfileIntrinsics()) - - return phase0 + phase1 + phase2 + phase3 + extras - - -def lower_module(inp: IRModule, simple_mode: bool = False) -> IRModule: - """Lowering step before building the target. - - Parameters - ---------- - inp : IRModule - The IRModule to be lowered. - simple_mode : bool - Whether to output only a simple, compact statement. - - Returns - ------- - IRModule - The lowered IRModule. 
- """ - return tvm.ir.transform.Sequential(create_pass_list(simple_mode))(inp) - - -def lower_primfunc(inp: PrimFunc, name: str = "main", simple_mode: bool = False) -> IRModule: - """Lowering step before building the target for a PrimFunc. - - Parameters - ---------- - inp : PrimFunc - The PrimFunc to be lowered. - name : str - The name of the resulting function. - simple_mode : bool - Whether to output only a simple, compact statement. - - Returns - ------- - IRModule - The lowered IRModule. - """ - pass_ctx = tvm.ir.transform.PassContext.current() - f = inp.with_attr("global_symbol", name) - if pass_ctx.config.get("tir.noalias", True): - f = f.with_attr("tir.noalias", True) - mod = tvm.ir.IRModule({tvm.ir.GlobalVar(name): f}) - return tvm.ir.transform.Sequential(create_pass_list(simple_mode))(mod) - - -def lower( - inp: Union[PrimFunc, IRModule], name: str = "main", simple_mode: bool = False -) -> IRModule: - """Lowering step before building the target. - - Parameters - ---------- - inp : Union[PrimFunc, IRModule] - The PrimFunc or IRModule to be lowered. - name : str - The name of the resulting function (if applicable). - simple_mode : bool - Whether to output only a simple, compact statement. - - Returns - ------- - IRModule - The lowered IRModule. - """ - if isinstance(inp, IRModule): - return lower_module(inp, simple_mode) - if isinstance(inp, PrimFunc): - return lower_primfunc(inp, name, simple_mode) - raise ValueError(f"Expected input to be IRModule or PrimFunc, but got {type(inp)}") - - -def check_and_update_host_consistency(targets: dict, host): - """ - Check and update the host field of the given legacy heterogeneous targets - for legacy target API compatibility. - - Parameters - ---------- - targets : dict - Dictionary mapping Target objects to IRModule objects. - host : Target - The target host to be updated. - """ - for tgt in list(targets): - if getattr(tgt, "host", None) is None: - tgt.host = host - - -def mixed_module_pass_manager(target: Target) -> tvm.ir.transform.Sequential: - """ - Constructs a Sequential transformation pass pipeline for a mixed module. - - Parameters - ---------- - target : Target - The target device for which the module is intended. - - Returns - ------- - tvm.ir.transform.Sequential - A sequential pass pipeline for the mixed module. - """ - pass_ctx = tvm.ir.transform.PassContext.current() - mixed_pass_list = [ - # Bind the target first so that target-specific attributes are available. - tir.transform.BindTarget(target), - tir.transform.FP8ComputeLegalize(), - # VerifyVTCMLimit must occur before LowerVtcmAlloc. - tir.transform.VerifyVTCMLimit(target), - tir.transform.LowerVtcmAlloc(), - tir.transform.VerifyMemory(), - tir.transform.AnnotateEntryFunc(), - ] - if pass_ctx.config.get("tir.detect_global_barrier", False): - mixed_pass_list.append(tir.transform.ThreadSync("global")) - mixed_pass_list.extend( - [ - tir.transform.ThreadSync("shared"), - tir.transform.ThreadSync("shared.dyn"), - tir.transform.ThreadSync("warp"), - tir.transform.InferFragment(), - tir.transform.LowerThreadAllreduce(), - ] - ) - if pass_ctx.config.get("tir.use_async_copy", False): - mixed_pass_list.append(tir.transform.InjectPTXAsyncCopy()) - if pass_ctx.config.get("tir.ptx_ldg32", False): - mixed_pass_list.append(tir.transform.InjectPTXLDG32()) - mixed_pass_list.extend( - [ - tir.transform.AnnotateDeviceRegions(), - tir.transform.SplitHostDevice(), - # MergeSharedMemoryAllocations must follow SplitHostDevice. 
- tir.transform.MergeSharedMemoryAllocations(), - tir.transform.MakePackedAPI(), - tir.transform.FP8StorageLegalize(), - tir.transform.BF16StorageLegalize(), - tir.transform.LowerDeviceKernelLaunch(), - ] - ) - return tvm.ir.transform.Sequential(mixed_pass_list) - - -class CallConv(enum.IntEnum): - """ - Enum representing different calling conventions. - Corresponds to the C++ tvm::ir::CallingConv enum. - """ - - kDefault = 0 - kCPackedFunc = 1 - kDeviceKernelLaunch = 2 - - -def host_module_pass_manager(target_host: Target) -> tvm.ir.transform.Sequential: - """ - Build a sequential pass pipeline for lowering the host part of a mixed module. - - Parameters - ---------- - target_host : Target - The host target for which to lower the module. - - Returns - ------- - tvm.ir.transform.Sequential - A sequential pass pipeline for host-specific transformations. - """ - host_pass_list = [ - # Filter out device kernel launches. - tir.transform.Filter( - lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault)) - != int(CallConv.kDeviceKernelLaunch) - ), - tir.transform.BindTarget(target_host), - tir.transform.LowerTVMBuiltin(), - tir.transform.LowerCustomDatatypes(), - tir.transform.LowerIntrin(), - tir.transform.LowerDeviceStorageAccessInfo(), - tir.transform.CombineContextCall(), - ] - return tvm.ir.transform.Sequential(host_pass_list) - - -def device_module_pass_manager(target: Target) -> tvm.ir.transform.Sequential: - """ - Build a sequential pass pipeline for lowering the device part of a mixed module. - - Parameters - ---------- - target : Target - The target for device-specific transformations. - - Returns - ------- - tvm.ir.transform.Sequential - A sequential pass pipeline for device-specific transformations. - """ - device_pass_list = [ - # Select only device kernel launches. - tir.transform.Filter( - lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault)) - == int(CallConv.kDeviceKernelLaunch) - ), - tir.transform.BindTarget(target), - tir.transform.LowerWarpMemory(), - tir.transform.Simplify(), - tir.transform.LowerCustomDatatypes(), - tir.transform.LowerDeviceStorageAccessInfo(), - tir.transform.LowerIntrin(), - ] - return tvm.ir.transform.Sequential(device_pass_list) - - -def split_mixed_module( - mod_mixed: IRModule, target_arg: Target, target_host_arg: Target -) -> Tuple[IRModule, IRModule]: - """ - Split a mixed module containing both device and host parts into separate modules, - applying appropriate transformations on each. - - Parameters - ---------- - mod_mixed : IRModule - The input module containing both device and host code. - target_arg : Target - The target for device-specific transformations. - target_host_arg : Target - The host target for lowering. - - Returns - ------- - Tuple[IRModule, IRModule] - (host module, device module) - """ - target, target_host = target_arg, target_host_arg - if getattr(target, "host", None) is None: - target.host = target_host - if mod_mixed is None: - raise ValueError("Module must be defined") - - mod_mixed = mixed_module_pass_manager(target)(mod_mixed) - host_mod = host_module_pass_manager(target_host)(mod_mixed) - device_mod = device_module_pass_manager(target)(mod_mixed) - - # Warn if target is GPU but no device code was generated. - if "gpu" in target.keys and len(device_mod.functions) == 0: - print( - f"Warning: Specified target {target} but cannot find device code. " - "Did you forget to bind?" 
- ) - - return host_mod, device_mod - - -def default_target_host(target: Target) -> Target: - """ - Determine the default target host for a given target. - """ - if target is not None and target.device_type == Device.kDLCPU: - return target - # In practice, llvm_enabled should be determined dynamically. - llvm_enabled = True - return Target("llvm") if llvm_enabled else Target("stackvm") - - def codegen_build(mod: IRModule, target: Target) -> tvm.runtime.Module: """ Build a runtime module from an IRModule and a Target. @@ -427,14 +56,18 @@ def codegen_build(mod: IRModule, target: Target) -> tvm.runtime.Module: return bf(mod, target) -def tir_to_runtime(inputs: Dict[Target, Tuple[IRModule, IRModule]], target_host: Target): +def tir_to_runtime(host_mod: IRModule, device_mod: IRModule, target, target_host: Target): """ Convert a collection of TIR IRModules (keyed by Target) into a single runtime Module. Parameters ---------- - inputs : dict - Mapping from Target to Tuple[IRModule, IRModule]. + host_mod : IRModule + The host module. + device_mod : IRModule + The device module. + target : Target + The target. target_host : Target The initial host target. @@ -446,19 +79,17 @@ def tir_to_runtime(inputs: Dict[Target, Tuple[IRModule, IRModule]], target_host: # Get the first module to get the attributes # necessary for tests/python/codegen/test_target_codegen_blob.py::test_cuda_multi_lib - first_module = next(iter(inputs.values()))[0] - mhost_all = ir.IRModule({}, attrs=first_module.attrs) + mhost_all = ir.IRModule({}, attrs=host_mod.attrs) device_modules = [] - for tgt, (host_mod, device_mod) in inputs.items(): - overrides_host_target = tgt.get_target_device_type() == target_host.get_target_device_type() - non_host_target_kind = tgt.kind != target_host.kind - if overrides_host_target and non_host_target_kind: - device_modules.append(codegen_build(host_mod, tgt)) - else: - mhost_all.update(host_mod) + overrides_host_target = target.get_target_device_type() == target_host.get_target_device_type() + non_host_target_kind = target.kind != target_host.kind + if overrides_host_target and non_host_target_kind: + device_modules.append(codegen_build(host_mod, target)) + else: + mhost_all.update(host_mod) if len(device_mod.functions) != 0: - device_modules.append(codegen_build(device_mod, tgt)) + device_modules.append(codegen_build(device_mod, target)) mhost = codegen_build(mhost_all, target_host) for dev_mod in device_modules: @@ -468,22 +99,21 @@ def tir_to_runtime(inputs: Dict[Target, Tuple[IRModule, IRModule]], target_host: def build( - inputs: Union[PrimFunc, IRModule], + mod: Union[PrimFunc, IRModule], target: Optional[Union[str, Target]] = None, - name: str = "main", + pipeline: Union[None, str, tvm.transform.Pass] = "default_tir", ): - """ - Build a function with a signature, generating code for devices + """Build a function with a signature, generating code for devices coupled with target information. Parameters ---------- - inputs : Union[PrimFunc, IRModule] + mod : Union[PrimFunc, IRModule] The input to be built. target : Optional[Union[str, Target]] The target for compilation. - name : str - The name of the result function. + pipeline : Union[None, str, tvm.transform.Pass] + The pipeline to use for compilation. Returns ------- @@ -491,60 +121,48 @@ def build( A module combining both host and device code. 
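+
+    Examples
+    --------
+    A minimal usage sketch; the TVMScript function below is illustrative
+    and not part of this change:
+
+    .. code-block:: python
+
+        import tvm
+        from tvm.script import tir as T
+
+        @T.prim_func
+        def add_one(A: T.Buffer((8,), "float32"), B: T.Buffer((8,), "float32")):
+            for i in range(8):
+                B[i] = A[i] + T.float32(1)
+
+        # A PrimFunc is wrapped into an IRModule before the pipeline runs.
+        lib = tvm.tir.build(add_one, target="llvm")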
""" # Convert PrimFunc to IRModule - pass_ctx = tvm.ir.transform.PassContext.current() - if isinstance(inputs, PrimFunc): - f = inputs.with_attr("global_symbol", name) - if pass_ctx.config.get("tir.noalias", True): - f = f.with_attr("tir.noalias", True) - input_mod = tvm.ir.IRModule({tvm.ir.GlobalVar(name): f}) - elif isinstance(inputs, tvm.IRModule): - input_mod = inputs + if isinstance(mod, PrimFunc): + mod = tvm.IRModule.from_expr(mod) else: - raise ValueError("Inputs must be IRModule or PrimFunc") + assert isinstance(mod, tvm.IRModule) - # Get target and target_host + # Step 0: Determine the target in environment target = Target.current() if target is None else target - if target is None and isinstance(input_mod, tvm.IRModule): - target_mod = {} - for gvar, func in input_mod.functions.items(): - tgt = func.attrs.get("target", "llvm") - target_mod.setdefault(tgt, {})[gvar] = func - target_input_mod = { - tgt: tvm.IRModule(funcs).with_attrs(input_mod.attrs) - for tgt, funcs in target_mod.items() - } + if target is None: + target = "llvm" + assert target is not None + target = Target.canon_target(target) + + # Step 1: Determine the host + target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" + if target is not None: + if target.host is not None: + target_host = target.host + elif ndarray.device(target.kind.name, 0).device_type == ndarray.cpu(0).device_type: + target_host = target else: - target_input_mod = {target: input_mod} - - annotated_mods = {} - for tgt, mod in target_input_mod.items(): - if not isinstance(tgt, (str, Target)): - raise ValueError("The key of inputs must be str or Target.") - if not isinstance(mod, tvm.IRModule): - raise ValueError("inputs must be IRModule, or dict of str to IRModule.") - annotated_mods[tgt] = mod - - annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods) - if not target_host: - for tar, mod in annotated_mods.items(): - if ndarray.device(tar.kind.name, 0).device_type == ndarray.cpu(0).device_type: - target_host = tar - break - if not target_host: - target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" - annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) - - assert annotated_mods is not None and target_host is not None - check_and_update_host_consistency(annotated_mods, target_host) - - # Lower the module - for tgt, mod in annotated_mods.items(): - mod = lower_module(mod, simple_mode=False) - host_mod, device_mod = split_mixed_module(mod, tgt, target_host) - annotated_mods[tgt] = (host_mod, device_mod) + for func in mod.functions.values(): + f_target = func.attrs.get("target", None) + if f_target is not None and f_target.host is not None: + target_host = f_target.host + assert target_host is not None + target_host = Target.canon_target(target_host) + target = target.with_host(target_host) + + # Step 2: Bind the target to the input module + mod = tvm.tir.transform.BindTarget(target)(mod) + # Step 3: Apply the pipeline + if pipeline is not None: + if isinstance(pipeline, str): + pipeline = tvm.tir.get_pipeline(pipeline) + mod = pipeline(mod) + + # Step 4: Finalize the host and device modules + host_mod = tvm.tir.pipeline.finalize_host_passes()(mod) + device_mod = tvm.tir.pipeline.finalize_device_passes()(mod) # Convert TIR IRModules to runtime Module by calling target.build - return tir_to_runtime(annotated_mods, target_host) + return tir_to_runtime(host_mod, device_mod, target, target_host) tvm.register_func("tir.build", build) diff --git 
a/python/tvm/tir/pipeline.py b/python/tvm/tir/pipeline.py new file mode 100644 index 000000000000..56a5ec784fa1 --- /dev/null +++ b/python/tvm/tir/pipeline.py @@ -0,0 +1,197 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=invalid-name +"""The TIR backend compilation pipeline.""" + +import enum +import tvm +from tvm import tir + + +def default_tir_pipeline(): + """The default tir pipeline used in tvm.tir.build""" + + @tvm.transform.module_pass(opt_level=0) + def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule: + """The default lowering passes for TIR backend.""" + pass_ctx = tvm.transform.PassContext.current() + config = pass_ctx.config + passes = [ + tir.transform.InjectPrefetch(), + tir.transform.TextureFlatten(), + tir.transform.StorageFlatten( + 64, bool(config.get("tir.instrument_bound_checkers", False)) + ), + tir.transform.LowerCrossThreadReduction(), + tir.transform.LowerInitBlock(), + tir.transform.PlanAndUpdateBufferAllocationLocation(), + tir.transform.ConvertBlocksToOpaque(), + tir.transform.LiftThreadBinding(), + tir.transform.ManifestSharedMemoryLocalStage(), + tir.transform.CompactBufferAllocation(), + tir.transform.LowerAutoCopy(), + tir.transform.UnifyThreadBinding(), + tir.transform.LowerMatchBuffer(), + tir.transform.Simplify(), + tir.transform.InjectPermutedLayout(), + tir.transform.InjectSoftwarePipeline(), + tir.transform.TransformMmaBufferLayout(), + tir.transform.LowerOpaqueBlock(), + tir.transform.FlattenBuffer(), + tir.transform.BF16ComputeLegalize(), + tir.transform.NarrowDataType(32), + tir.transform.LoopPartition(), + tir.transform.VectorizeLoop(not bool(config.get("tir.disable_vectorize", False))), + tir.transform.InjectVirtualThread(), + tir.transform.InjectDoubleBuffer(), + ] + if not bool(config.get("tir.disable_storage_rewrite", False)): + passes.append(tir.transform.StorageRewrite()) + if config.get("tir.use_async_copy", False): + passes.append(tir.transform.LowerAsyncDMA()) + passes.extend( + [ + tir.transform.HoistIfThenElse(), + tir.transform.UnrollLoop(), + tir.transform.RenormalizeSplitPattern(), + tir.transform.Simplify(), + tir.transform.RemoveNoOp(), + tir.transform.RewriteUnsafeSelect(), + ] + ) + # Additional passes based on configuration. 
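+        # Each flag below is read with config.get(..., False), so unset flags
+        # behave as False: the instrumentation and ptx_ldg32 passes are opt-in,
+        # while CSE runs unless tir.disable_cse_tir is set.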
+ if bool(config.get("tir.instrument_bound_checkers", False)): + passes.append(tir.transform.InstrumentBoundCheckers()) + if bool(config.get("tir.ptx_ldg32", False)): + passes.append(tir.transform.InjectPTXLDG32(True)) + passes.append( + tir.transform.CommonSubexprElimTIR( + not bool(config.get("tir.disable_cse_tir", False)), + bool(config.get("tir.enable_equiv_terms_in_cse_tir", False)), + ) + ) + if bool(config.get("tir.instrument_lwp", False)): + passes.append(tir.transform.InstrumentProfileIntrinsics()) + passes.extend( + [ + # Bind the target first so that target-specific attributes are available. + tir.transform.FP8ComputeLegalize(), + # VerifyVTCMLimit must occur before LowerVtcmAlloc. + tir.transform.VerifyVTCMLimit(), + tir.transform.LowerVtcmAlloc(), + tir.transform.VerifyMemory(), + tir.transform.AnnotateEntryFunc(), + ] + ) + if bool(config.get("tir.detect_global_barrier", False)): + passes.append(tir.transform.ThreadSync("global")) + passes.extend( + [ + tir.transform.ThreadSync("shared"), + tir.transform.ThreadSync("shared.dyn"), + tir.transform.ThreadSync("warp"), + tir.transform.InferFragment(), + tir.transform.LowerThreadAllreduce(), + ] + ) + if bool(config.get("tir.use_async_copy", False)): + passes.append(tir.transform.InjectPTXAsyncCopy()) + if bool(config.get("tir.ptx_ldg32", False)): + passes.append(tir.transform.InjectPTXLDG32()) + passes.extend( + [ + tir.transform.AnnotateDeviceRegions(), + tir.transform.SplitHostDevice(), + # MergeSharedMemoryAllocations must follow SplitHostDevice. + tir.transform.MergeSharedMemoryAllocations(), + tir.transform.MakePackedAPI(), + tir.transform.FP8StorageLegalize(), + tir.transform.BF16StorageLegalize(), + tir.transform.LowerDeviceKernelLaunch(), + ] + ) + mod = tvm.ir.transform.Sequential(passes)(mod) + return mod + + return _pipeline + + +class CallConv(enum.IntEnum): + """ + Enum representing different calling conventions. + Corresponds to the C++ tvm::ir::CallingConv enum. + """ + + kDefault = 0 + kCPackedFunc = 1 + kDeviceKernelLaunch = 2 + + +def finalize_host_passes(): # pylint: disable=unused-argument + """The default finalization passes for TIR backend.""" + host_pass_list = [ + # Filter out device kernel launches. + tir.transform.Filter( + lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault)) + != int(CallConv.kDeviceKernelLaunch) + ), + tir.transform.LowerTVMBuiltin(), + tir.transform.LowerCustomDatatypes(), + tir.transform.LowerIntrin(), + tir.transform.LowerDeviceStorageAccessInfo(), + tir.transform.CombineContextCall(), + ] + return tvm.ir.transform.Sequential(host_pass_list) + + +def finalize_device_passes(): # pylint: disable=unused-argument + """The default finalization passes for TIR backend.""" + device_pass_list = [ + # Select only device kernel launches. 
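+        # SplitHostDevice marks extracted kernels with calling_conv =
+        # kDeviceKernelLaunch; filtering on that attribute selects only them.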
+        tir.transform.Filter(
+            lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault))
+            == int(CallConv.kDeviceKernelLaunch)
+        ),
+        tir.transform.LowerWarpMemory(),
+        tir.transform.Simplify(),
+        tir.transform.LowerCustomDatatypes(),
+        tir.transform.LowerDeviceStorageAccessInfo(),
+        tir.transform.LowerIntrin(),
+    ]
+    return tvm.ir.transform.Sequential(device_pass_list)
+
+
+# Global map of pre-built pipelines.
+PIPELINE_MAP = {
+    "default_tir": default_tir_pipeline,
+}
+
+
+def get_pipeline(name: str = "default_tir", **kwargs) -> tvm.transform.Pass:
+    """Get a pre-built pipeline by name.
+
+    Parameters
+    ----------
+    name : str
+        Name of the pipeline
+    """
+    if name not in PIPELINE_MAP:
+        raise ValueError(
+            f"Unknown pre-built pipeline {name}, "
+            f"candidates are {list(PIPELINE_MAP.keys())}"
+        )
+    return PIPELINE_MAP[name](**kwargs)
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index fb5a1ba79669..99a2e1e66485 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -713,7 +713,7 @@ def VerifyMemory():
     return _ffi_api.VerifyMemory()  # type: ignore
 
 
-def VerifyVTCMLimit(limit: int):
+def VerifyVTCMLimit(limit=None):
     """Verify if the size of the allocated vtcm memory satisfies the limit.
 
     Returns
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index c65dda7d597a..0ed80310eb97 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -46,25 +46,6 @@ using tvm::runtime::TVMArgs;
 using tvm::runtime::TVMRetValue;
 
 TVM_REGISTER_PASS_CONFIG_OPTION("testing.immutable_module", Bool);
-// Register build pipeline related options
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.noalias", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_buffer_level_predication", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_static_smem", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool);
 
 struct PassContextThreadLocalEntry {
   /*! \brief The default pass context.
 */
diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc
index fbc43a00cad7..1c77219d453e 100644
--- a/src/tir/ir/transform.cc
+++ b/src/tir/ir/transform.cc
@@ -29,6 +29,26 @@
 namespace tvm {
 namespace tir {
 namespace transform {
 
+// Register build pipeline related options
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.noalias", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_buffer_level_predication", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_static_smem", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool);
+
 /*!
  * \brief Function level pass that applies transformations to all
  * TIR functions within the module.
diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py
index 7910cd372ffc..92358c4a4b06 100644
--- a/tests/python/codegen/test_target_codegen_llvm.py
+++ b/tests/python/codegen/test_target_codegen_llvm.py
@@ -761,43 +761,6 @@ def check_llvm_ir():
     check_llvm_ir()
 
 
-@tvm.testing.requires_llvm
-def test_llvm_shuffle():
-    a = te.placeholder((8,), "int32")
-    b = te.placeholder((8,), "int32")
-    c = te.compute((8,), lambda x: a[x] + b[7 - x])
-
-    # Convert to TIR and create schedule
-    mod = te.create_prim_func([a, b, c])
-    sch = tir.Schedule(mod)
-
-    def my_vectorize():
-        def vectorizer(op):
-            store = op.body
-            idx = tvm.tir.Ramp(tvm.tir.const(0, "int32"), tvm.tir.const(1, "int32"), 8)
-            value = store.value
-            b_idx = tvm.tir.Shuffle([idx], [tvm.tir.const(i, "int32") for i in range(7, -1, -1)])
-            new_a = tvm.tir.BufferLoad(value.a.buffer, [idx])
-            new_b = tvm.tir.BufferLoad(value.b.buffer, [b_idx])
-            value = new_a + new_b
-            return tvm.tir.BufferStore(store.buffer, new_a + new_b, [idx])
-
-        def _transform(f, *_):
-            return f.with_body(
-                tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ["tir.For"])
-            )
-
-        return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize")
-
-    with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, my_vectorize())]}):
-        module = tvm.build(sch.mod)
-    a_ = tvm.nd.array(np.arange(1, 9, dtype="int32"))
-    b_ = tvm.nd.array(np.arange(8, 0, -1, dtype="int32"))
-    c_ = tvm.nd.array(np.zeros((8,), dtype="int32"))
-    module(a_, b_, c_)
-    tvm.testing.assert_allclose(c_.numpy(), (a_.numpy() * 2).astype("int32"))
-
-
 def np_float2np_bf16(arr):
     """Convert a numpy array of float to a numpy array of bf16 in uint16"""
 
From 9ee17ae8f92574a262b6529c2d61b7eb1dc1c11f Mon Sep 17 00:00:00 2001
From: spectrometerHBH
Date: Thu, 20 Feb 2025 02:51:32 +0000
Subject: [PATCH 11/17] lint

---
 .../tir-transform/test_tir_transform_narrow_datatype.py | 4 +++-
 1 file
changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py index e5cb3667633c..a7b528093967 100644 --- a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py +++ b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py @@ -154,7 +154,9 @@ def check(m, n, target_bits, target_dtype): # The maximum index is (2**15 * 2**15 - 1) * 2 <= 2**31 - 1 check(const(2**15, "int64"), const(2**15, "int64"), target_bits=32, target_dtype="int32") # The maximum index is (2**15 * 2**15 - 1 + 2**15) * 2 > 2**31 - 1 - check(const(2**15, "int64"), const((2**15 + 1), "int64"), target_bits=32, target_dtype="int64") + check( + const(2**15, "int64"), const((2**15 + 1), "int64"), target_bits=32, target_dtype="int64" + ) def test_condition(): From c644f8fb9313391dfbbd57c2d76b77091a688085 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Thu, 20 Feb 2025 03:01:17 +0000 Subject: [PATCH 12/17] lint --- python/tvm/tir/build.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/tvm/tir/build.py b/python/tvm/tir/build.py index 516eb5d1d114..64d667a21ed8 100644 --- a/python/tvm/tir/build.py +++ b/python/tvm/tir/build.py @@ -17,15 +17,13 @@ # pylint: disable=invalid-name """The build utils in python.""" -from typing import Union, Optional, Dict, Tuple -import enum +from typing import Union, Optional import tvm -from tvm import tir, ir +from tvm import ir from tvm.runtime import ndarray from tvm.tir import PrimFunc from tvm.ir.module import IRModule from tvm.target import Target -from tvm._ffi.runtime_ctypes import Device def codegen_build(mod: IRModule, target: Target) -> tvm.runtime.Module: From bf8ed67cb6d5871ea48463a77497a47b85e69970 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Thu, 20 Feb 2025 04:14:56 +0000 Subject: [PATCH 13/17] fix --- tests/python/codegen/test_target_codegen_llvm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index 92358c4a4b06..304c79559cbb 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -42,7 +42,7 @@ def test_llvm_intrin(): body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "prefetch")) - fcode = tvm.build(mod, None, "llvm") + fcode = tvm.build(mod, None) @tvm.testing.requires_llvm @@ -54,7 +54,7 @@ def test_llvm_void_intrin(): ib.emit(x) body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main")) - fcode = tvm.build(mod, None, "llvm") + fcode = tvm.build(mod, None) @tvm.testing.requires_llvm @@ -106,7 +106,7 @@ def test_llvm_lookup_intrin(): ib.emit(x) body = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main")) - fcode = tvm.build(mod, None, "llvm") + fcode = tvm.build(mod, None) @tvm.testing.requires_llvm From 9c79bb07d82c42fe24fdd158ee4f1234f3e8e909 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Thu, 20 Feb 2025 04:58:55 +0000 Subject: [PATCH 14/17] fix --- docs/how_to/tutorials/cross_compilation_and_rpc.py | 2 +- src/relax/transform/fold_constant.cc | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/how_to/tutorials/cross_compilation_and_rpc.py b/docs/how_to/tutorials/cross_compilation_and_rpc.py index 81c73fd051ef..94a6f48b4b73 100644 --- 
a/docs/how_to/tutorials/cross_compilation_and_rpc.py +++ b/docs/how_to/tutorials/cross_compilation_and_rpc.py @@ -119,7 +119,7 @@ else: target = "llvm -mtriple=armv7l-linux-gnueabihf" -func = tvm.build(mod, target=target, name="add_one") +func = tvm.build(mod, target=target) # save the lib at a local temp folder temp = utils.tempdir() path = temp.relpath("lib.tar") diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index 91a1e806cefc..fb6a01a19d7f 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -117,7 +117,8 @@ class ConstantFolder : public ExprMutator { // TODO(Hongyi): further check and narrow the scope of foldable function auto* pf = runtime::Registry::Get("tir.build"); ICHECK(pf != nullptr) << "Cannot find tir.build in registry"; - runtime::Module rt_module = (*pf)(func, eval_cpu_target, "tir_function"); + func = WithAttr(func, tvm::attr::kGlobalSymbol, String("tir_function")); + runtime::Module rt_module = (*pf)(func, eval_cpu_target); build_func = rt_module.GetFunction("tir_function"); } catch (const tvm::Error& err) { // build failure may happen in which case we skip From 806a45012d515afde50884b569a1d04c2ed9896a Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Thu, 20 Feb 2025 05:21:12 +0000 Subject: [PATCH 15/17] fix --- tests/python/contrib/test_hexagon/test_meta_schedule.py | 2 +- tests/python/contrib/test_hexagon/test_sigmoid.py | 2 +- .../test_tir_schedule_tensorize_ldmatrix_mma_numeric.py | 2 +- .../tir-schedule/test_tir_schedule_tensorize_mfma_numeric.py | 2 +- .../tir-transform/test_tir_transform_lower_tvm_builtin.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/python/contrib/test_hexagon/test_meta_schedule.py b/tests/python/contrib/test_hexagon/test_meta_schedule.py index 26acedb88e21..c0c7355a9afa 100644 --- a/tests/python/contrib/test_hexagon/test_meta_schedule.py +++ b/tests/python/contrib/test_hexagon/test_meta_schedule.py @@ -156,7 +156,7 @@ def schedule_dense(sch, block, m_size, do_tune): def verify_dense(sch, target, m_size, n_size, k_size, hexagon_session): """Verify dense operator.""" - f = tvm.build(sch.mod["main"], target=target, name="dense") + f = tvm.build(sch.mod["main"], target=target) mod = hexagon_session.load_module(f) dev = hexagon_session.device diff --git a/tests/python/contrib/test_hexagon/test_sigmoid.py b/tests/python/contrib/test_hexagon/test_sigmoid.py index cc633795c217..1247d9075972 100644 --- a/tests/python/contrib/test_hexagon/test_sigmoid.py +++ b/tests/python/contrib/test_hexagon/test_sigmoid.py @@ -92,7 +92,7 @@ def test_sigmoid( func_name = "sigmoid" with tvm.transform.PassContext(opt_level=3): - runtime_module = tvm.build(tir_s.mod, target=get_hexagon_target("v69"), name=func_name) + runtime_module = tvm.build(tir_s.mod, target=get_hexagon_target("v69")) assert "hvx_sigmoid" in runtime_module.get_source("asm") assert "vmin" in runtime_module.get_source("asm") diff --git a/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py b/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py index 390745fe9d96..fe9998bc798e 100644 --- a/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py +++ b/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py @@ -117,7 +117,7 @@ def run_test( mma_store_intrin, ) - f = tvm.build(sch.mod["main"], target="cuda", name="dense") + f = tvm.build(sch.mod["main"], target="cuda") dev = tvm.device("cuda", 0) 
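With `name=` removed from tvm.build, a compiled entry point is looked up by the
function's "global_symbol" attribute instead, as the fold_constant.cc hunk above
does in C++. A minimal Python sketch (function and symbol names are illustrative):

    # `dense_func` is a hypothetical PrimFunc; its global symbol replaces `name=`.
    func = dense_func.with_attr("global_symbol", "dense")
    rt_mod = tvm.build(func, target="llvm")
    dense = rt_mod["dense"]  # fetch the compiled function by its symbol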
diff --git a/tests/python/tir-schedule/test_tir_schedule_tensorize_mfma_numeric.py b/tests/python/tir-schedule/test_tir_schedule_tensorize_mfma_numeric.py index 8077a603bcf2..2b3e6ce39bfb 100644 --- a/tests/python/tir-schedule/test_tir_schedule_tensorize_mfma_numeric.py +++ b/tests/python/tir-schedule/test_tir_schedule_tensorize_mfma_numeric.py @@ -109,7 +109,7 @@ def run_test( mma_store_intrin, ) - f = tvm.build(sch.mod["main"], target="rocm", name="dense") + f = tvm.build(sch.mod["main"], target="rocm") dev = tvm.device("rocm", 0) if in_dtype == "float32": diff --git a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py index 754ce032404d..0a040b0eeadb 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py @@ -179,7 +179,7 @@ def build_tir(): ) mod = build_tir() - f = tvm.build(mod, None, "llvm") + f = tvm.build(mod, None) a = tvm.nd.array(np.zeros(2, dtype="float32")) f(a) tvm.testing.assert_allclose(a.numpy(), expected_value) From 5dee69df7e4e12106144b83f3063e25d515a02fb Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Thu, 20 Feb 2025 14:19:51 +0000 Subject: [PATCH 16/17] fix --- python/tvm/tir/build.py | 99 +++++++++++++++++++++----------------- python/tvm/tir/pipeline.py | 21 -------- 2 files changed, 56 insertions(+), 64 deletions(-) diff --git a/python/tvm/tir/build.py b/python/tvm/tir/build.py index 64d667a21ed8..cd44ed881ba3 100644 --- a/python/tvm/tir/build.py +++ b/python/tvm/tir/build.py @@ -17,7 +17,9 @@ # pylint: disable=invalid-name """The build utils in python.""" -from typing import Union, Optional +from typing import Union, Optional, Dict +import enum + import tvm from tvm import ir from tvm.runtime import ndarray @@ -26,25 +28,49 @@ from tvm.target import Target -def codegen_build(mod: IRModule, target: Target) -> tvm.runtime.Module: - """ - Build a runtime module from an IRModule and a Target. - - If the "tir.disable_assert" flag is set in the pass context, - the SkipAssert transformation is applied. +def split_host_device_mods(mod): + """Split an IRModule into host and device modules. Parameters ---------- - mod : IRModule - The input IRModule. - target : Target - The target for which to build the module. + mod : tvm.IRModule + The input module to split Returns ------- - tvm.runtime.Module - The built runtime module. + host_mod : tvm.IRModule + The module containing host functions + device_mod_dict : Dict[Target, tvm.IRModule] + A dict mapping targets to device modules """ + + class CallConv(enum.IntEnum): + """Enum representing different calling conventions. + Corresponds to the C++ tvm::ir::CallingConv enum. 
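+
+        SplitHostDevice tags extracted kernels with kDeviceKernelLaunch;
+        split_host_device_mods filters on this attribute below.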
+ """ + + kDefault = 0 + kCPackedFunc = 1 + kDeviceKernelLaunch = 2 + + host_mod = tvm.tir.transform.Filter( + lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault)) + != int(CallConv.kDeviceKernelLaunch) + )(mod) + device_mod = tvm.tir.transform.Filter( + lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault)) + == int(CallConv.kDeviceKernelLaunch) + )(mod) + device_mod_dict = {} + for gv, func in device_mod.functions.items(): + device_mod_dict.setdefault(func.attrs.get("target", None), dict()).update({gv: func}) + for target, funcs in device_mod_dict.items(): + device_mod_dict[target] = tvm.IRModule(funcs, attrs=device_mod.attrs) + return host_mod, device_mod_dict + + +def codegen_build(mod: IRModule, target: Target) -> tvm.runtime.Module: + """Build a runtime module from an IRModule and a Target.""" if tvm.ir.transform.PassContext.current().config.get("tir.disable_assert", False): mod = tvm.tir.transform.SkipAssert()(mod) build_f_name = "target.build." + target.kind.name @@ -54,38 +80,18 @@ def codegen_build(mod: IRModule, target: Target) -> tvm.runtime.Module: return bf(mod, target) -def tir_to_runtime(host_mod: IRModule, device_mod: IRModule, target, target_host: Target): - """ - Convert a collection of TIR IRModules (keyed by Target) into a single runtime Module. - - Parameters - ---------- - host_mod : IRModule - The host module. - device_mod : IRModule - The device module. - target : Target - The target. - target_host : Target - The initial host target. - - Returns - ------- - tvm.runtime.Module - The final runtime module. - """ +def tir_to_runtime( + host_mod: IRModule, device_mod_dict: Dict[Target, IRModule], target_host: Target +): + """Convert a collection of TIR IRModules (keyed by Target) into a single runtime Module.""" # Get the first module to get the attributes # necessary for tests/python/codegen/test_target_codegen_blob.py::test_cuda_multi_lib mhost_all = ir.IRModule({}, attrs=host_mod.attrs) + mhost_all.update(host_mod) device_modules = [] - overrides_host_target = target.get_target_device_type() == target_host.get_target_device_type() - non_host_target_kind = target.kind != target_host.kind - if overrides_host_target and non_host_target_kind: - device_modules.append(codegen_build(host_mod, target)) - else: - mhost_all.update(host_mod) + for target, device_mod in device_mod_dict.items(): if len(device_mod.functions) != 0: device_modules.append(codegen_build(device_mod, target)) @@ -149,18 +155,25 @@ def build( # Step 2: Bind the target to the input module mod = tvm.tir.transform.BindTarget(target)(mod) + # Step 3: Apply the pipeline if pipeline is not None: if isinstance(pipeline, str): pipeline = tvm.tir.get_pipeline(pipeline) mod = pipeline(mod) - # Step 4: Finalize the host and device modules - host_mod = tvm.tir.pipeline.finalize_host_passes()(mod) - device_mod = tvm.tir.pipeline.finalize_device_passes()(mod) + # Step 4: Get host and device modules + host_mod, device_mod_dict = split_host_device_mods(mod) + + # Step 5: Apply finalization passes + host_mod = tvm.tir.pipeline.finalize_host_passes()(host_mod) + device_mod_dict = { + target: tvm.tir.pipeline.finalize_device_passes()(device_mod) + for target, device_mod in device_mod_dict.items() + } # Convert TIR IRModules to runtime Module by calling target.build - return tir_to_runtime(host_mod, device_mod, target, target_host) + return tir_to_runtime(host_mod, device_mod_dict, target_host) tvm.register_func("tir.build", build) diff --git a/python/tvm/tir/pipeline.py b/python/tvm/tir/pipeline.py index 
56a5ec784fa1..ef6d25c4e9e2 100644 --- a/python/tvm/tir/pipeline.py +++ b/python/tvm/tir/pipeline.py @@ -131,25 +131,9 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I return _pipeline -class CallConv(enum.IntEnum): - """ - Enum representing different calling conventions. - Corresponds to the C++ tvm::ir::CallingConv enum. - """ - - kDefault = 0 - kCPackedFunc = 1 - kDeviceKernelLaunch = 2 - - def finalize_host_passes(): # pylint: disable=unused-argument """The default finalization passes for TIR backend.""" host_pass_list = [ - # Filter out device kernel launches. - tir.transform.Filter( - lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault)) - != int(CallConv.kDeviceKernelLaunch) - ), tir.transform.LowerTVMBuiltin(), tir.transform.LowerCustomDatatypes(), tir.transform.LowerIntrin(), @@ -162,11 +146,6 @@ def finalize_host_passes(): # pylint: disable=unused-argument def finalize_device_passes(): # pylint: disable=unused-argument """The default finalization passes for TIR backend.""" device_pass_list = [ - # Select only device kernel launches. - tir.transform.Filter( - lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault)) - == int(CallConv.kDeviceKernelLaunch) - ), tir.transform.LowerWarpMemory(), tir.transform.Simplify(), tir.transform.LowerCustomDatatypes(), From 33bf9dc26acc613a8970870fdf2b4ca939642706 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Thu, 20 Feb 2025 18:29:32 +0000 Subject: [PATCH 17/17] fix --- python/tvm/tir/pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/tir/pipeline.py b/python/tvm/tir/pipeline.py index ef6d25c4e9e2..0b6d622c90e1 100644 --- a/python/tvm/tir/pipeline.py +++ b/python/tvm/tir/pipeline.py @@ -18,7 +18,6 @@ # pylint: disable=invalid-name """The TIR backend compilation pipeline.""" -import enum import tvm from tvm import tir
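Taken together, the series retires tvm.lower in favor of a pipeline-driven
tvm.tir.build: BindTarget, then the named pipeline, then the calling_conv-based
host/device split and finalization passes, then per-target codegen. An
end-to-end sketch of the resulting API (the TVMScript module and targets are
illustrative, not taken from the patches above):

    import numpy as np
    import tvm
    from tvm.script import ir as I, tir as T

    @I.ir_module
    class Mod:
        @T.prim_func
        def main(A: T.Buffer((8,), "float32"), B: T.Buffer((8,), "float32")):
            for i in range(8):
                B[i] = A[i] + T.float32(1)

    # "default_tir" is also the default pipeline; spelled out here for clarity.
    lib = tvm.tir.build(Mod, target="llvm", pipeline="default_tir")

    a = tvm.nd.array(np.arange(8, dtype="float32"))
    b = tvm.nd.empty((8,), "float32")
    lib["main"](a, b)
    np.testing.assert_allclose(b.numpy(), a.numpy() + 1.0)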