From 76eadba86e7b9edb142e766eb4e03f83ec5e877a Mon Sep 17 00:00:00 2001 From: zheliuyu <15750543867@163.com> Date: Sat, 13 Sep 2025 11:32:27 +0800 Subject: [PATCH 01/10] add npu support --- pytest.ini | 1 + src/kernels/utils.py | 10 +++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/pytest.ini b/pytest.ini index d8fc63c2..608f664a 100644 --- a/pytest.ini +++ b/pytest.ini @@ -4,3 +4,4 @@ markers = rocm_only: marks tests that should only run on hosts with ROCm GPUs darwin_only: marks tests that should only run on macOS xpu_only: marks tests that should only run on hosts with Intel XPUs + npu_only: marks tests that should only run on Ascend NPUs diff --git a/src/kernels/utils.py b/src/kernels/utils.py index c956f4f7..70e091f3 100644 --- a/src/kernels/utils.py +++ b/src/kernels/utils.py @@ -49,9 +49,14 @@ def build_variant() -> str: elif torch.version.xpu is not None: version = torch.version.xpu compute_framework = f"xpu{version[0:4]}{version[5:6]}" + elif torch.npu.is_available(): + import torch_npu + torch_npu_version = parse(torch_npu.__version__[:5]) + compute_framework = "cann" + else: raise AssertionError( - "Torch was not compiled with CUDA, Metal, XPU, or ROCm enabled." + "Torch was not compiled with CUDA, Metal, XPU, NPU, or ROCm enabled." ) torch_version = parse(torch.__version__) @@ -61,6 +66,9 @@ def build_variant() -> str: if os == "darwin": cpu = "aarch64" if cpu == "arm64" else cpu return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}" + + if torch.npu.is_available(): + return f"torch{torch_version}-torch_npu{torch_npu_version}-{compute_framework}-{cpu}-{os}" cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98" From 3e4bf0d0f96c2ef861c266cd7af6b7d0801126a7 Mon Sep 17 00:00:00 2001 From: zheliuyu <15750543867@163.com> Date: Sun, 14 Sep 2025 11:41:58 +0800 Subject: [PATCH 02/10] fix return build file's name --- src/kernels/utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/kernels/utils.py b/src/kernels/utils.py index 70e091f3..65092df9 100644 --- a/src/kernels/utils.py +++ b/src/kernels/utils.py @@ -51,9 +51,9 @@ def build_variant() -> str: compute_framework = f"xpu{version[0:4]}{version[5:6]}" elif torch.npu.is_available(): import torch_npu - torch_npu_version = parse(torch_npu.__version__[:5]) + torch_npu_version = parse(torch_npu.__version__) + tn_major, tn_minor = torch_npu_version.major, torch_npu_version.minor compute_framework = "cann" - else: raise AssertionError( "Torch was not compiled with CUDA, Metal, XPU, NPU, or ROCm enabled." @@ -68,7 +68,8 @@ def build_variant() -> str: return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}" if torch.npu.is_available(): - return f"torch{torch_version}-torch_npu{torch_npu_version}-{compute_framework}-{cpu}-{os}" + t_major, t_minor = torch_version.major, torch_version.minor + return f"torch{t_major}{t_minor}-torch_npu{tn_major}{tn_minor}-{compute_framework}-{cpu}-{os}" cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98" From 6f73ee723ba9e89ace594a20872cf5225b030036 Mon Sep 17 00:00:00 2001 From: zheliuyu <15750543867@163.com> Date: Sun, 14 Sep 2025 14:54:17 +0800 Subject: [PATCH 03/10] fix --- src/kernels/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/kernels/utils.py b/src/kernels/utils.py index 65092df9..9c345c81 100644 --- a/src/kernels/utils.py +++ b/src/kernels/utils.py @@ -49,7 +49,7 @@ def build_variant() -> str: elif torch.version.xpu is not None: version = torch.version.xpu compute_framework = f"xpu{version[0:4]}{version[5:6]}" - elif torch.npu.is_available(): + elif torch._C._get_privateuse1_backend_name() == "npu": import torch_npu torch_npu_version = parse(torch_npu.__version__) tn_major, tn_minor = torch_npu_version.major, torch_npu_version.minor @@ -67,7 +67,7 @@ def build_variant() -> str: cpu = "aarch64" if cpu == "arm64" else cpu return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}" - if torch.npu.is_available(): + if torch._C._get_privateuse1_backend_name() == "npu": t_major, t_minor = torch_version.major, torch_version.minor return f"torch{t_major}{t_minor}-torch_npu{tn_major}{tn_minor}-{compute_framework}-{cpu}-{os}" From ec9d30ad22a65804241592de7f9971802fdfb0fa Mon Sep 17 00:00:00 2001 From: zheliuyu <15750543867@163.com> Date: Mon, 15 Sep 2025 17:17:36 +0800 Subject: [PATCH 04/10] Add support for npu layer repostories --- src/kernels/layer.py | 31 ++++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/src/kernels/layer.py b/src/kernels/layer.py index 8eedb3a9..7a8775d7 100644 --- a/src/kernels/layer.py +++ b/src/kernels/layer.py @@ -87,7 +87,7 @@ class Device: Args: type (`str`): - The device type (e.g., "cuda", "mps", "rocm", "xpu"). + The device type (e.g., "cuda", "mps", "rocm", "xpu", "npu"). properties ([`CUDAProperties`], *optional*): Device-specific properties. Currently only [`CUDAProperties`] is supported for CUDA devices. @@ -109,6 +109,9 @@ class Device: # XPU device (e.g., Intel(R) Data Center GPU Max 1550) xpu_device = Device(type="xpu") + + # NPU device (e.g., Huawei Ascend Atlas A2) + npu_device = Device(type="npu") ``` """ @@ -130,6 +133,8 @@ def create_repo(self) -> _DeviceRepos: return _MPSRepos() elif self.type == "xpu": return _XPURepos() + elif self.type == "npu": + return _NPURepos() else: raise ValueError(f"Unknown device type: {self.type}") @@ -472,6 +477,26 @@ def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]): self._repos = repos +class _NPURepos(_DeviceRepos): + _repos: Dict[Mode, LayerRepositoryProtocol] + + def __init__(self): + super().__init__() + self._repos = {} + + @property + def repos( + self, + ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]: + return self._repos + + def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]): + if device.type != "npu": + raise ValueError(f"Device type must be 'npu', got {device.type}") + + self._repos = repos + + class _MPSRepos(_DeviceRepos): _repos: Dict[Mode, LayerRepositoryProtocol] @@ -556,7 +581,7 @@ def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]): def _validate_device_type(device_type: str) -> None: """Validate that the device type is supported.""" - supported_devices = {"cuda", "rocm", "mps", "xpu"} + supported_devices = {"cuda", "rocm", "mps", "xpu", "npu"} if device_type not in supported_devices: raise ValueError( f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}" @@ -814,7 +839,7 @@ def kernelize( `Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training with `torch.compile`. device (`Union[str, torch.device]`, *optional*): - The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm", "xpu". + The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm", "xpu", "npu". The device type will be inferred from the model parameters when not provided. use_fallback (`bool`, *optional*, defaults to `True`): Whether to use the original forward method of modules when no compatible kernel could be found. From f8f9f6368efe0eb809610d0750aa5ff7d4efdd7d Mon Sep 17 00:00:00 2001 From: zheliuyu <15750543867@163.com> Date: Tue, 16 Sep 2025 17:25:37 +0800 Subject: [PATCH 05/10] fix comments text --- src/kernels/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kernels/layer.py b/src/kernels/layer.py index 7a8775d7..ea460813 100644 --- a/src/kernels/layer.py +++ b/src/kernels/layer.py @@ -863,7 +863,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return F.silu(x[..., :d]) * x[..., d:] mapping = { - "LayerNorm": { + "SiluAndMul": { "cuda": LayerRepository( repo_id="kernels-community/activation", layer_name="SiluAndMul", From 26d8bd3034ebf30c487ce3c2167b0afd5ff401d1 Mon Sep 17 00:00:00 2001 From: zheliuyu <15750543867@163.com> Date: Tue, 16 Sep 2025 20:01:47 +0800 Subject: [PATCH 06/10] add test_layers on npu --- tests/conftest.py | 5 +++++ tests/test_layer.py | 38 +++++++++++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 6d9d379d..4a2f3dc4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,6 +18,9 @@ and torch.version.xpu is not None and torch.xpu.device_count() > 0 ) +has_npu = ( + torch._C._get_privateuse1_backend_name() == "npu" +) def pytest_runtest_setup(item): @@ -29,3 +32,5 @@ def pytest_runtest_setup(item): pytest.skip("skipping macOS-only test on non-macOS platform") if "xpu_only" in item.keywords and not has_xpu: pytest.skip("skipping XPU-only test on host without XPU") + if "npu_only" in item.keywords and not has_npu: + pytest.skip("skipping NPU-only test on host without NPU") diff --git a/tests/test_layer.py b/tests/test_layer.py index 0d17ce1f..75e3fe50 100644 --- a/tests/test_layer.py +++ b/tests/test_layer.py @@ -52,6 +52,12 @@ layer_name="LigerRMSNorm", # Triton ) }, + "SwiGlu": { + "npu": LayerRepository( + repo_id="kernels-ext-npu/SwiGlu", + layer_name="SwiGlu", + ) + }, } register_kernel_mapping(kernel_layer_mapping) @@ -104,6 +110,11 @@ class SiluAndMulStringDevice(SiluAndMul): pass +@use_kernel_forward_from_hub("SwiGlu") +class SiluAndMulNPU(SiluAndMul): + pass + + @use_kernel_forward_from_hub("Linear") class TorchLinearWithCounter(nn.Linear): def __init__(self, *args, **kwargs): @@ -122,8 +133,10 @@ def device(): return "cuda" elif hasattr(torch, "xpu") and torch.xpu.is_available(): return "xpu" + elif torch._C._get_privateuse1_backend_name() == "npu": + return "npu" - pytest.skip("No CUDA or XPU") + pytest.skip("No CUDA, NPU or XPU") def test_arg_kinds(): @@ -204,10 +217,33 @@ def test_hub_forward_xpu(): assert rms_norm_with_kernel.n_calls == 0 +@pytest.mark.npu_only +def test_hub_forward_npu(): + torch.manual_seed(0) + + silu_and_mul = SiluAndMul() + X = torch.randn((32, 64), device="npu") + Y = silu_and_mul(X) + + silu_and_mul_with_kernel = kernelize( + SiluAndMulNPU(), device="npu", mode=Mode.INFERENCE + ) + Y_kernel = silu_and_mul_with_kernel(X) + + torch.testing.assert_close(Y_kernel, Y) + + assert silu_and_mul.n_calls == 1 + assert silu_and_mul_with_kernel.n_calls == 0 + + @pytest.mark.skipif( hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)(), reason="Skip on xpu devices", ) +@pytest.mark.skipif( + torch._C._get_privateuse1_backend_name() == "npu", + reason="Skip on npu devices", +) def test_rocm_kernel_mapping(): """Test that ROCm shorthand device mapping works correctly.""" kernel_layer_mapping = { From daf56a54c76ce978ecc7d5d2fd47c70d27c4f16d Mon Sep 17 00:00:00 2001 From: zheliuyu <15750543867@163.com> Date: Wed, 17 Sep 2025 10:37:15 +0800 Subject: [PATCH 07/10] fix npu_kernel_layer_mapping() --- tests/test_layer.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/test_layer.py b/tests/test_layer.py index 358e267b..cbccadd7 100644 --- a/tests/test_layer.py +++ b/tests/test_layer.py @@ -52,12 +52,6 @@ layer_name="LigerRMSNorm", # Triton ) }, - "SwiGlu": { - "npu": LayerRepository( - repo_id="kernels-ext-npu/SwiGlu", - layer_name="SwiGlu", - ) - }, } register_kernel_mapping(kernel_layer_mapping) @@ -219,6 +213,17 @@ def test_hub_forward_xpu(): @pytest.mark.npu_only def test_hub_forward_npu(): + npu_kernel_layer_mapping = { + "SwiGlu": { + "npu": LayerRepository( + repo_id="kernels-ext-npu/SwiGlu", + layer_name="SwiGlu", + ) + } + } + + register_kernel_mapping(npu_kernel_layer_mapping) + torch.manual_seed(0) silu_and_mul = SiluAndMul() From e598c4aa167a36af50667193fc07949e64a0eb34 Mon Sep 17 00:00:00 2001 From: Chunyu <15750543867@163.com> Date: Thu, 18 Sep 2025 10:39:43 +0800 Subject: [PATCH 08/10] update compute_framework string &fix unit-tests on npu MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Daniƫl de Kok --- src/kernels/layer.py | 8 ++++---- src/kernels/utils.py | 20 +++++++++++--------- tests/conftest.py | 4 +++- tests/test_kernel_locking.py | 1 + tests/test_layer.py | 33 ++++++++++++--------------------- 5 files changed, 31 insertions(+), 35 deletions(-) diff --git a/src/kernels/layer.py b/src/kernels/layer.py index 2de88ee2..f3e5265a 100644 --- a/src/kernels/layer.py +++ b/src/kernels/layer.py @@ -87,7 +87,7 @@ class Device: Args: type (`str`): - The device type (e.g., "cuda", "mps", "rocm", "xpu", "npu"). + The device type (e.g., "cuda", "mps", "npu", "rocm", "xpu"). properties ([`CUDAProperties`], *optional*): Device-specific properties. Currently only [`CUDAProperties`] is supported for CUDA devices. @@ -110,7 +110,7 @@ class Device: # XPU device (e.g., Intel(R) Data Center GPU Max 1550) xpu_device = Device(type="xpu") - # NPU device (e.g., Huawei Ascend Atlas A2) + # NPU device (Huawei Ascend) npu_device = Device(type="npu") ``` """ @@ -581,7 +581,7 @@ def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]): def _validate_device_type(device_type: str) -> None: """Validate that the device type is supported.""" - supported_devices = {"cuda", "rocm", "mps", "xpu", "npu"} + supported_devices = {"cuda", "mps", "npu", "rocm", "xpu"} if device_type not in supported_devices: raise ValueError( f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}" @@ -839,7 +839,7 @@ def kernelize( `Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training with `torch.compile`. device (`Union[str, torch.device]`, *optional*): - The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm", "xpu", "npu". + The device type to load kernels for. Supported device types are: "cuda", "mps", "npu", "rocm", "xpu". The device type will be inferred from the model parameters when not provided. use_fallback (`bool`, *optional*, defaults to `True`): Whether to use the original forward method of modules when no compatible kernel could be found. diff --git a/src/kernels/utils.py b/src/kernels/utils.py index 9c345c81..2bae1c19 100644 --- a/src/kernels/utils.py +++ b/src/kernels/utils.py @@ -35,6 +35,13 @@ def _get_cache_dir() -> Optional[str]: CACHE_DIR: Optional[str] = _get_cache_dir() +def _get_privateuse_backend_name() -> Optional[str]: + import torch + if hasattr(torch._C, "_get_privateuse1_backend_name"): + return torch._C._get_privateuse1_backend_name() + return None + + def build_variant() -> str: import torch @@ -49,11 +56,10 @@ def build_variant() -> str: elif torch.version.xpu is not None: version = torch.version.xpu compute_framework = f"xpu{version[0:4]}{version[5:6]}" - elif torch._C._get_privateuse1_backend_name() == "npu": - import torch_npu - torch_npu_version = parse(torch_npu.__version__) - tn_major, tn_minor = torch_npu_version.major, torch_npu_version.minor - compute_framework = "cann" + elif _get_privateuse_backend_name() == "npu": + from torch_npu.utils.collect_env import get_cann_version + cann_major, cann_minor = get_cann_version()[0], get_cann_version()[2] + compute_framework = f"cann{cann_major}{cann_minor}" else: raise AssertionError( "Torch was not compiled with CUDA, Metal, XPU, NPU, or ROCm enabled." @@ -66,10 +72,6 @@ def build_variant() -> str: if os == "darwin": cpu = "aarch64" if cpu == "arm64" else cpu return f"torch{torch_version.major}{torch_version.minor}-{compute_framework}-{cpu}-{os}" - - if torch._C._get_privateuse1_backend_name() == "npu": - t_major, t_minor = torch_version.major, torch_version.minor - return f"torch{t_major}{t_minor}-torch_npu{tn_major}{tn_minor}-{compute_framework}-{cpu}-{os}" cxxabi = "cxx11" if torch.compiled_with_cxx11_abi() else "cxx98" diff --git a/tests/conftest.py b/tests/conftest.py index ecda1a58..82fcea0c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,8 @@ import pytest import torch +from kernels.utils import _get_privateuse_backend_name + has_cuda = ( hasattr(torch.version, "cuda") and torch.version.cuda is not None @@ -19,7 +21,7 @@ and torch.xpu.device_count() > 0 ) has_npu = ( - torch._C._get_privateuse1_backend_name() == "npu" + _get_privateuse_backend_name() == "npu" ) diff --git a/tests/test_kernel_locking.py b/tests/test_kernel_locking.py index e4691b2a..7daaa889 100644 --- a/tests/test_kernel_locking.py +++ b/tests/test_kernel_locking.py @@ -35,6 +35,7 @@ def test_load_locked(): load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock") +@pytest.mark.cuda_only def test_layer_locked(): project_dir = Path(__file__).parent / "layer_locking" diff --git a/tests/test_layer.py b/tests/test_layer.py index cbccadd7..6d0a8b8e 100644 --- a/tests/test_layer.py +++ b/tests/test_layer.py @@ -21,14 +21,21 @@ _KERNEL_MAPPING, _validate_layer, ) -from kernels.utils import install_kernel +from kernels.utils import ( + _get_privateuse_backend_name, + install_kernel, +) kernel_layer_mapping = { "SiluAndMul": { Device(type="cuda"): LayerRepository( repo_id="kernels-community/activation", layer_name="SiluAndMul", - ) + ), + "npu": LayerRepository( + repo_id="kernels-ext-npu/SwiGlu", + layer_name="SwiGlu", + ), }, "SiluAndMulNoCompile": { "cuda": LayerRepository( @@ -104,11 +111,6 @@ class SiluAndMulStringDevice(SiluAndMul): pass -@use_kernel_forward_from_hub("SwiGlu") -class SiluAndMulNPU(SiluAndMul): - pass - - @use_kernel_forward_from_hub("Linear") class TorchLinearWithCounter(nn.Linear): def __init__(self, *args, **kwargs): @@ -127,7 +129,7 @@ def device(): return "cuda" elif hasattr(torch, "xpu") and torch.xpu.is_available(): return "xpu" - elif torch._C._get_privateuse1_backend_name() == "npu": + elif _get_privateuse_backend_name() == "npu": return "npu" pytest.skip("No CUDA, NPU or XPU") @@ -213,17 +215,6 @@ def test_hub_forward_xpu(): @pytest.mark.npu_only def test_hub_forward_npu(): - npu_kernel_layer_mapping = { - "SwiGlu": { - "npu": LayerRepository( - repo_id="kernels-ext-npu/SwiGlu", - layer_name="SwiGlu", - ) - } - } - - register_kernel_mapping(npu_kernel_layer_mapping) - torch.manual_seed(0) silu_and_mul = SiluAndMul() @@ -231,7 +222,7 @@ def test_hub_forward_npu(): Y = silu_and_mul(X) silu_and_mul_with_kernel = kernelize( - SiluAndMulNPU(), device="npu", mode=Mode.INFERENCE + SiluAndMulWithKernel(), device="npu", mode=Mode.INFERENCE ) Y_kernel = silu_and_mul_with_kernel(X) @@ -246,7 +237,7 @@ def test_hub_forward_npu(): reason="Skip on xpu devices", ) @pytest.mark.skipif( - torch._C._get_privateuse1_backend_name() == "npu", + _get_privateuse_backend_name() == "npu", reason="Skip on npu devices", ) def test_rocm_kernel_mapping(): From 2ad1931ddf81689aa436ccff8a75f574a2654ad8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 23 Sep 2025 07:41:53 +0000 Subject: [PATCH 09/10] black --- src/kernels/utils.py | 2 ++ tests/conftest.py | 4 +--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/kernels/utils.py b/src/kernels/utils.py index 2bae1c19..5b3dc382 100644 --- a/src/kernels/utils.py +++ b/src/kernels/utils.py @@ -37,6 +37,7 @@ def _get_cache_dir() -> Optional[str]: def _get_privateuse_backend_name() -> Optional[str]: import torch + if hasattr(torch._C, "_get_privateuse1_backend_name"): return torch._C._get_privateuse1_backend_name() return None @@ -58,6 +59,7 @@ def build_variant() -> str: compute_framework = f"xpu{version[0:4]}{version[5:6]}" elif _get_privateuse_backend_name() == "npu": from torch_npu.utils.collect_env import get_cann_version + cann_major, cann_minor = get_cann_version()[0], get_cann_version()[2] compute_framework = f"cann{cann_major}{cann_minor}" else: diff --git a/tests/conftest.py b/tests/conftest.py index 82fcea0c..7369646e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,9 +20,7 @@ and torch.version.xpu is not None and torch.xpu.device_count() > 0 ) -has_npu = ( - _get_privateuse_backend_name() == "npu" -) +has_npu = _get_privateuse_backend_name() == "npu" def pytest_addoption(parser): From 135757f0e05a1d62d7a03be064e6e351a57937af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 23 Sep 2025 07:51:52 +0000 Subject: [PATCH 10/10] Fix mypy error --- src/kernels/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kernels/utils.py b/src/kernels/utils.py index 5b3dc382..9cfe6fe0 100644 --- a/src/kernels/utils.py +++ b/src/kernels/utils.py @@ -58,7 +58,7 @@ def build_variant() -> str: version = torch.version.xpu compute_framework = f"xpu{version[0:4]}{version[5:6]}" elif _get_privateuse_backend_name() == "npu": - from torch_npu.utils.collect_env import get_cann_version + from torch_npu.utils.collect_env import get_cann_version # type: ignore[import-not-found] cann_major, cann_minor = get_cann_version()[0], get_cann_version()[2] compute_framework = f"cann{cann_major}{cann_minor}"