From c6569597a9303302d0ed300eb4b6a91c754e316f Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Fri, 7 Mar 2025 20:58:39 +0000 Subject: [PATCH] Fix the get_target_compute_version for sm >= 100 --- python/tvm/contrib/nvcc.py | 12 +++++++----- python/tvm/relax/vm_build.py | 7 +++++-- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/python/tvm/contrib/nvcc.py b/python/tvm/contrib/nvcc.py index d12ddf883cf4..c8b749b36bf1 100644 --- a/python/tvm/contrib/nvcc.py +++ b/python/tvm/contrib/nvcc.py @@ -292,13 +292,15 @@ def get_target_compute_version(target=None): target = target or Target.current() if target and target.arch: arch = target.arch.split("_")[1] - if len(arch) == 2: - major, minor = arch - return major + "." + minor - elif len(arch) == 3: + if len(arch) < 2: + raise ValueError(f"The arch is not expected {target.arch}") + if arch[-1].isalpha(): # This is for arch like "sm_90a" - major, minor, suffix = arch + suffix = arch[-1] + major = arch[:-2] + minor = arch[-2] return major + "." + minor + "." + suffix + return arch[:-1] + "." + arch[-1] # 3. GPU compute version if tvm.cuda(0).exist: diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py index f44fcb9c226c..564c4a747d48 100644 --- a/python/tvm/relax/vm_build.py +++ b/python/tvm/relax/vm_build.py @@ -197,8 +197,11 @@ def build( params: Optional[Dict[str, list]] Parameters for the input IRModule that will be bound. - pipeline : str = "default_build" - The compilation pipeline to use. + relax_pipeline : str = "default" + The Relax compilation pipeline to use. + + tir_pipelinie : str = "default" + The TIR compilation pipeline to use. exec_mode: {"bytecode", "compiled"} The execution mode.