Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions kernels/src/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)
from kernels.utils import (
get_kernel,
get_loaded_kernels,
get_local_kernel,
get_locked_kernel,
has_kernel,
Expand All @@ -45,6 +46,7 @@
"LockedLayerRepository",
"Mode",
"get_kernel",
"get_loaded_kernels",
"get_local_kernel",
"get_locked_kernel",
"has_kernel",
Expand Down
47 changes: 45 additions & 2 deletions kernels/src/kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from importlib.metadata import Distribution
from pathlib import Path
from types import ModuleType
from typing import NamedTuple

from huggingface_hub import HfApi, constants

Expand All @@ -33,6 +34,29 @@
KNOWN_BACKENDS = {"cpu", "cuda", "metal", "neuron", "rocm", "xpu", "npu"}


class RepoInfos(NamedTuple):
repo_id: str
revision: str | None
version: int | str | None
backend: str | None


class LoadedKernel(NamedTuple):
variant_path: Path
package_name: str
module_name: str
torch_namespace: str | None
repo_infos: RepoInfos | None


_loaded_kernels: dict[str, LoadedKernel] = {}


def get_loaded_kernels() -> dict[str, LoadedKernel]:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is still a TODO, but I am curious what the use case of this function is, since it maps internal module names. I don't think we want to expose this to the user?

Copy link
Copy Markdown
Member

@danieldk danieldk Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the loaded kernels need to be externally probed, maybe we can give the kernel a more predictable name in sys.modules (e.g. a kernels-specific prefix)? In that way the module table could be probed and it would be more resistant against e.g. someone unloading modules.

"""Returns a copy of the loaded kernels registry (`module_name -> LoadedKernel` mapping)."""
return _loaded_kernels.copy()


def _get_cache_dir() -> str | None:
"""Returns the kernels cache directory."""
cache_dir = os.environ.get("HF_KERNELS_CACHE", None)
Expand Down Expand Up @@ -71,7 +95,9 @@ def _parse_local_kernel_overrides(local_kernels: str) -> dict[str, Path]:
CACHE_DIR: str | None = _get_cache_dir()


def _import_from_path(module_name: str, variant_path: Path) -> ModuleType:
def _import_from_path(
module_name: str, variant_path: Path, _repo_infos: RepoInfos | None = None
) -> ModuleType:
metadata = Metadata.load_from_variant(variant_path)
validate_dependencies(module_name, metadata.python_depends, _backend())

Expand All @@ -83,6 +109,7 @@ def _import_from_path(module_name: str, variant_path: Path) -> ModuleType:
# it would also be used for other imports. So, we make a module name that
# depends on the path for it to be unique using the hex-encoded hash of
# the path.
package_name = module_name
path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path)).value)
module_name = f"{module_name}_{path_hash}"
spec = importlib.util.spec_from_file_location(module_name, file_path)
Expand All @@ -93,6 +120,16 @@ def _import_from_path(module_name: str, variant_path: Path) -> ModuleType:
raise ImportError(f"Cannot load module {module_name} from spec")
sys.modules[module_name] = module
spec.loader.exec_module(module) # type: ignore
torch_namespace: str | None = None
if (ops := sys.modules.get(f"{module_name}._ops")) is not None:
torch_namespace = getattr(ops.ops, "name", None)
_loaded_kernels[module_name] = LoadedKernel(
torch_namespace=torch_namespace,
variant_path=variant_path,
package_name=package_name,
module_name=module_name,
repo_infos=_repo_infos,
)
return module


Expand Down Expand Up @@ -282,7 +319,13 @@ def get_kernel(
package_name, variant_path = install_kernel(
repo_id, revision=revision, backend=backend, user_agent=user_agent
)
return _import_from_path(package_name, variant_path)
repo_infos = RepoInfos(
repo_id=repo_id,
revision=revision,
version=version,
backend=backend,
)
return _import_from_path(package_name, variant_path, _repo_infos=repo_infos)


def get_local_kernel(
Expand Down
Loading