Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions bittensor/core/metagraph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import contextlib
import copy
import os
import pickle
Expand All @@ -11,6 +12,7 @@
import numpy as np
from async_substrate_interface.errors import SubstrateRequestException
from numpy.typing import NDArray
from packaging import version

from bittensor.core import settings
from bittensor.core.chain_data import (
Expand Down Expand Up @@ -143,6 +145,27 @@ def latest_block_path(dir_path: str) -> str:
return latest_file_full_path


def safe_globals():
    """
    Return a context manager for wrapping ``torch.load`` calls.

    torch 2.6+ restricts which globals may be unpickled, so the numpy
    constructors used by saved metagraph state files must be allow-listed via
    ``torch.serialization.safe_globals``. Older torch versions need no special
    handling, so a no-op ``nullcontext`` is returned instead.
    """
    torch_release = version.parse(torch.__version__).release
    if torch_release < version.parse("2.6").release:
        return contextlib.nullcontext()

    # numpy 2.0 moved the private ``np.core`` module to ``np._core``.
    if version.parse(np.__version__) >= version.parse("2.0.0"):
        numpy_core = np._core
    else:
        numpy_core = np.core

    allowed_globals = [
        numpy_core.multiarray._reconstruct,
        np.ndarray,
        np.dtype,
        type(np.dtype(np.uint32)),
        np.dtypes.Float32DType,
        bytes,
    ]
    return torch.serialization.safe_globals(allowed_globals)


class MetagraphMixin(ABC):
"""
The metagraph class is a core component of the Bittensor network, representing the neural graph that forms the
Expand Down Expand Up @@ -1124,7 +1147,8 @@ def load_from_path(self, dir_path: str) -> "MetagraphMixin":
"""

graph_file = latest_block_path(dir_path)
state_dict = torch.load(graph_file)
with safe_globals():
state_dict = torch.load(graph_file)
self.n = torch.nn.Parameter(state_dict["n"], requires_grad=False)
self.block = torch.nn.Parameter(state_dict["block"], requires_grad=False)
self.uids = torch.nn.Parameter(state_dict["uids"], requires_grad=False)
Expand Down Expand Up @@ -1256,7 +1280,8 @@ def load_from_path(self, dir_path: str) -> "MetagraphMixin":
try:
import torch as real_torch

state_dict = real_torch.load(graph_filename)
with safe_globals():
state_dict = real_torch.load(graph_filename)
for key in METAGRAPH_STATE_DICT_NDARRAY_KEYS:
state_dict[key] = state_dict[key].detach().numpy()
del real_torch
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ dev = [
"aioresponses==0.7.6",
"factory-boy==3.3.0",
"types-requests",
"torch>=1.13.1,<2.6.0"
"torch>=1.13.1,<3.0"
]
torch = [
"torch>=1.13.1,<2.6.0"
"torch>=1.13.1,<3.0"
]
cli = [
"bittensor-cli>=9.0.2"
Expand Down
2 changes: 1 addition & 1 deletion requirements/torch.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
torch>=1.13.1,<2.6.0
torch>=1.13.1,<3.0