diff --git a/bittensor/core/metagraph.py b/bittensor/core/metagraph.py index 5554b32a3d..af758939ea 100644 --- a/bittensor/core/metagraph.py +++ b/bittensor/core/metagraph.py @@ -1,4 +1,5 @@ import asyncio +import contextlib import copy import os import pickle @@ -11,6 +12,7 @@ import numpy as np from async_substrate_interface.errors import SubstrateRequestException from numpy.typing import NDArray +from packaging import version from bittensor.core import settings from bittensor.core.chain_data import ( @@ -143,6 +145,27 @@ def latest_block_path(dir_path: str) -> str: return latest_file_full_path +def safe_globals(): + """ + Context manager to load torch files for version 2.6+ + """ + if version.parse(torch.__version__).release < version.parse("2.6").release: + return contextlib.nullcontext() + + np_core = ( + np._core if version.parse(np.__version__) >= version.parse("2.0.0") else np.core + ) + allow_list = [ + np_core.multiarray._reconstruct, + np.ndarray, + np.dtype, + type(np.dtype(np.uint32)), + np.dtypes.Float32DType, + bytes, + ] + return torch.serialization.safe_globals(allow_list) + + class MetagraphMixin(ABC): """ The metagraph class is a core component of the Bittensor network, representing the neural graph that forms the @@ -1124,7 +1147,8 @@ def load_from_path(self, dir_path: str) -> "MetagraphMixin": """ graph_file = latest_block_path(dir_path) - state_dict = torch.load(graph_file) + with safe_globals(): + state_dict = torch.load(graph_file) self.n = torch.nn.Parameter(state_dict["n"], requires_grad=False) self.block = torch.nn.Parameter(state_dict["block"], requires_grad=False) self.uids = torch.nn.Parameter(state_dict["uids"], requires_grad=False) @@ -1256,7 +1280,8 @@ def load_from_path(self, dir_path: str) -> "MetagraphMixin": try: import torch as real_torch - state_dict = real_torch.load(graph_filename) + with safe_globals(): + state_dict = real_torch.load(graph_filename) for key in METAGRAPH_STATE_DICT_NDARRAY_KEYS: state_dict[key] = state_dict[key].detach().numpy() del real_torch diff --git a/pyproject.toml b/pyproject.toml index 686fad4301..76cba69863 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,10 +59,10 @@ dev = [ "aioresponses==0.7.6", "factory-boy==3.3.0", "types-requests", - "torch>=1.13.1,<2.6.0" + "torch>=1.13.1,<3.0" ] torch = [ - "torch>=1.13.1,<2.6.0" + "torch>=1.13.1,<3.0" ] cli = [ "bittensor-cli>=9.0.2" diff --git a/requirements/torch.txt b/requirements/torch.txt index 1abaa00adc..07a6bcf5aa 100644 --- a/requirements/torch.txt +++ b/requirements/torch.txt @@ -1 +1 @@ -torch>=1.13.1,<2.6.0 +torch>=1.13.1,<3.0