diff --git a/kernels/src/kernels/__init__.py b/kernels/src/kernels/__init__.py index 248f712f..17f7cef8 100644 --- a/kernels/src/kernels/__init__.py +++ b/kernels/src/kernels/__init__.py @@ -23,6 +23,7 @@ ) from kernels.utils import ( get_kernel, + get_loaded_kernels, get_local_kernel, get_locked_kernel, has_kernel, @@ -45,6 +46,7 @@ "LockedLayerRepository", "Mode", "get_kernel", + "get_loaded_kernels", "get_local_kernel", "get_locked_kernel", "has_kernel", diff --git a/kernels/src/kernels/utils.py b/kernels/src/kernels/utils.py index fbe4dbce..11feaa56 100644 --- a/kernels/src/kernels/utils.py +++ b/kernels/src/kernels/utils.py @@ -12,6 +12,7 @@ from importlib.metadata import Distribution from pathlib import Path from types import ModuleType +from typing import NamedTuple from huggingface_hub import HfApi, constants @@ -33,6 +34,29 @@ KNOWN_BACKENDS = {"cpu", "cuda", "metal", "neuron", "rocm", "xpu", "npu"} +class RepoInfos(NamedTuple): + repo_id: str + revision: str | None + version: int | str | None + backend: str | None + + +class LoadedKernel(NamedTuple): + variant_path: Path + package_name: str + module_name: str + torch_namespace: str | None + repo_infos: RepoInfos | None + + +_loaded_kernels: dict[str, LoadedKernel] = {} + + +def get_loaded_kernels() -> dict[str, LoadedKernel]: + """Returns a copy of the loaded kernels registry (`module_name -> LoadedKernel` mapping).""" + return _loaded_kernels.copy() + + def _get_cache_dir() -> str | None: """Returns the kernels cache directory.""" cache_dir = os.environ.get("HF_KERNELS_CACHE", None) @@ -71,7 +95,9 @@ def _parse_local_kernel_overrides(local_kernels: str) -> dict[str, Path]: CACHE_DIR: str | None = _get_cache_dir() -def _import_from_path(module_name: str, variant_path: Path) -> ModuleType: +def _import_from_path( + module_name: str, variant_path: Path, _repo_infos: RepoInfos | None = None +) -> ModuleType: metadata = Metadata.load_from_variant(variant_path) validate_dependencies(module_name, metadata.python_depends, _backend()) @@ -83,6 +109,7 @@ def _import_from_path(module_name: str, variant_path: Path) -> ModuleType: # it would also be used for other imports. So, we make a module name that # depends on the path for it to be unique using the hex-encoded hash of # the path. + package_name = module_name path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path)).value) module_name = f"{module_name}_{path_hash}" spec = importlib.util.spec_from_file_location(module_name, file_path) @@ -93,6 +120,16 @@ def _import_from_path(module_name: str, variant_path: Path) -> ModuleType: raise ImportError(f"Cannot load module {module_name} from spec") sys.modules[module_name] = module spec.loader.exec_module(module) # type: ignore + torch_namespace: str | None = None + if (ops := sys.modules.get(f"{module_name}._ops")) is not None: + torch_namespace = getattr(ops.ops, "name", None) + _loaded_kernels[module_name] = LoadedKernel( + torch_namespace=torch_namespace, + variant_path=variant_path, + package_name=package_name, + module_name=module_name, + repo_infos=_repo_infos, + ) return module @@ -282,7 +319,13 @@ def get_kernel( package_name, variant_path = install_kernel( repo_id, revision=revision, backend=backend, user_agent=user_agent ) - return _import_from_path(package_name, variant_path) + repo_infos = RepoInfos( + repo_id=repo_id, + revision=revision, + version=version, + backend=backend, + ) + return _import_from_path(package_name, variant_path, _repo_infos=repo_infos) def get_local_kernel(