diff --git a/bittensor/extrinsics/registration.py b/bittensor/extrinsics/registration.py index 8be4963180..879214ad92 100644 --- a/bittensor/extrinsics/registration.py +++ b/bittensor/extrinsics/registration.py @@ -20,7 +20,12 @@ import time from rich.prompt import Confirm from typing import List, Union, Optional, Tuple -from bittensor.utils.registration import POWSolution, create_pow, torch, use_torch +from bittensor.utils.registration import ( + POWSolution, + create_pow, + torch, + log_no_torch_error, +) def register_extrinsic( @@ -100,7 +105,8 @@ def register_extrinsic( ): return False - if not use_torch(): + if not torch: + log_no_torch_error() return False # Attempt rolling registration. @@ -380,8 +386,8 @@ def run_faucet_extrinsic( ): return False, "" - if not use_torch(): - torch.error() + if not torch: + log_no_torch_error() return False, "Requires torch" # Unlock coldkey diff --git a/bittensor/utils/registration.py b/bittensor/utils/registration.py index ff3816ddbb..d6a046f195 100644 --- a/bittensor/utils/registration.py +++ b/bittensor/utils/registration.py @@ -1,16 +1,20 @@ import binascii +import functools import hashlib import math import multiprocessing import os import random import time +import typing from dataclasses import dataclass from datetime import timedelta from queue import Empty, Full from typing import Any, Callable, Dict, List, Optional, Tuple, Union import backoff +import numpy + import bittensor from Crypto.Hash import keccak from rich import console as rich_console @@ -19,62 +23,77 @@ from .formatting import get_human_readable, millify from ._register_cuda import solve_cuda -try: - import torch -except ImportError: - torch = None - def use_torch() -> bool: + """Force the use of torch over numpy for certain operations.""" return True if os.getenv("USE_TORCH") == "1" else False -class Torch: - def __init__(self): - self._transformed = False +def legacy_torch_api_compat(func): + """Decorator to convert between numpy arrays and torch 
tensors before/after passing them to the function. + Args: + func (function): + Function to be decorated. + Returns: + decorated (function): + Decorated function. + """ - @staticmethod - def _error(): - bittensor.logging.warning( - "This command requires torch. You can install torch for bittensor" - ' with `pip install bittensor[torch]` or `pip install ".[torch]"`' - " if installing from source, and then run the command with USE_TORCH=1 {command}" - ) + @functools.wraps(func) + def decorated(*args, **kwargs): + if use_torch(): + # if argument is a Torch tensor, convert it to numpy + args = [ + arg.cpu().numpy() if isinstance(arg, torch.Tensor) else arg + for arg in args + ] + kwargs = { + key: value.cpu().numpy() if isinstance(value, torch.Tensor) else value + for key, value in kwargs.items() + } + ret = func(*args, **kwargs) + if use_torch(): + # if return value is a numpy array, convert it to Torch tensor + if isinstance(ret, numpy.ndarray): + ret = torch.from_numpy(ret) + return ret + + return decorated + + +@functools.cache +def _get_real_torch(): + try: + import torch as _real_torch + except ImportError: + _real_torch = None + return _real_torch - def error(self): - self._error() - def _transform(self): - try: - import torch as real_torch +def log_no_torch_error(): + bittensor.btlogging.error( + "This command requires torch. 
You can install torch for bittensor" + ' with `pip install bittensor[torch]` or `pip install ".[torch]"`' + " if installing from source, and then run the command with USE_TORCH=1 {command}" + ) - self.__dict__.update(real_torch.__dict__) - self._transformed = True - except ImportError: - self._error() +class LazyLoadedTorch: def __bool__(self): - return False + return bool(_get_real_torch()) def __getattr__(self, name): - if not self._transformed and use_torch(): - self._transform() - if self._transformed: - return getattr(self, name) + if real_torch := _get_real_torch(): + return getattr(real_torch, name) else: - self._error() - - def __call__(self, *args, **kwargs): - if not self._transformed and use_torch(): - self._transform() - if self._transformed: - return self(*args, **kwargs) - else: - self._error() + log_no_torch_error() + raise ImportError("torch not installed") -if not torch or not use_torch(): - torch = Torch() +if typing.TYPE_CHECKING: + import torch +else: + torch = LazyLoadedTorch() class CUDAException(Exception): diff --git a/bittensor/utils/weight_utils.py b/bittensor/utils/weight_utils.py index 109951e14b..9bd8606c9d 100644 --- a/bittensor/utils/weight_utils.py +++ b/bittensor/utils/weight_utils.py @@ -22,12 +22,13 @@ import bittensor from numpy.typing import NDArray from typing import Tuple, List, Union -from bittensor.utils.registration import torch, use_torch +from bittensor.utils.registration import torch, use_torch, legacy_torch_api_compat U32_MAX = 4294967295 U16_MAX = 65535 +@legacy_torch_api_compat def normalize_max_weight( x: Union[NDArray[np.float32], "torch.FloatTensor"], limit: float = 0.1 ) -> Union[NDArray[np.float32], "torch.FloatTensor"]: @@ -43,14 +44,8 @@ def normalize_max_weight( """ epsilon = 1e-7 # For numerical stability after normalization - weights = x.clone() if use_torch() else x.copy() - if use_torch(): - values, _ = torch.sort(weights) - else: - values = np.sort(weights) - - if use_torch() and x.sum() == 0 or len(x) * 
limit <= 1: - return torch.ones_like(x) / x.size(0) + weights = x.copy() + values = np.sort(weights) if x.sum() == 0 or x.shape[0] * limit <= 1: return np.ones_like(x) / x.shape[0] @@ -61,18 +56,11 @@ return weights / weights.sum() # Find the cumlative sum and sorted tensor - cumsum = ( - torch.cumsum(estimation, 0) if use_torch() else np.cumsum(estimation, 0) - ) + cumsum = np.cumsum(estimation, 0) # Determine the index of cutoff - estimation_sum_data = [ - (len(values) - i - 1) * estimation[i] for i in range(len(values)) - ] - estimation_sum = ( - torch.tensor(estimation_sum_data) - if use_torch() - else np.array(estimation_sum_data) + estimation_sum = np.array( + [(len(values) - i - 1) * estimation[i] for i in range(len(values))] ) n_values = (estimation / (estimation_sum + cumsum + epsilon) < limit).sum() diff --git a/example.env b/example.env index de5fb400ed..35d405fb58 100644 --- a/example.env +++ b/example.env @@ -1,6 +1,5 @@ -# To use Torch functionality in bittensor, you must set the USE_TORCH flag to 1: -USE_TORCH=1 - -# If set to 0 (or anything else), you will use the numpy functions. -# This is generally what you want unless you have a specific reason for using torch -# such as POW registration or legacy interoperability. \ No newline at end of file +# To use the legacy Torch-based interface of bittensor, you must set USE_TORCH=1 +USE_TORCH=0 +# If set to 0 (or anything other than 1), it will use the current, numpy-based bittensor interface. +# This is generally what you want unless you want legacy interoperability. +# Please note that the legacy interface is deprecated, and is not tested nearly as much. 
diff --git a/tests/integration_tests/test_cli.py b/tests/integration_tests/test_cli.py index a449604a80..c20c905549 100644 --- a/tests/integration_tests/test_cli.py +++ b/tests/integration_tests/test_cli.py @@ -2090,7 +2090,6 @@ def test_register(self, _): def test_pow_register(self, _): # Not the best way to do this, but I need to finish these tests, and unittest doesn't make this # as simple as pytest - os.environ["USE_TORCH"] = "1" config = self.config config.command = "subnets" config.subcommand = "pow_register" @@ -2114,7 +2113,6 @@ class MockException(Exception): mock_create_wallet.assert_called_once() self.assertEqual(mock_is_stale.call_count, 1) - del os.environ["USE_TORCH"] def test_stake(self, _): amount_to_stake: Balance = Balance.from_tao(0.5) diff --git a/tests/integration_tests/test_subtensor_integration.py b/tests/integration_tests/test_subtensor_integration.py index 5c2ff0cf34..845a73ee7d 100644 --- a/tests/integration_tests/test_subtensor_integration.py +++ b/tests/integration_tests/test_subtensor_integration.py @@ -423,7 +423,6 @@ def test_is_hotkey_registered_not_registered(self): self.assertFalse(registered, msg="Hotkey should not be registered") def test_registration_multiprocessed_already_registered(self): - os.environ["USE_TORCH"] = "1" workblocks_before_is_registered = random.randint(5, 10) # return False each work block but return True after a random number of blocks is_registered_return_values = ( @@ -477,10 +476,8 @@ def test_registration_multiprocessed_already_registered(self): self.subtensor.is_hotkey_registered.call_count == workblocks_before_is_registered + 2 ) - del os.environ["USE_TORCH"] def test_registration_partly_failed(self): - os.environ["USE_TORCH"] = "1" do_pow_register_mock = MagicMock( side_effect=[(False, "Failed"), (False, "Failed"), (True, None)] ) @@ -514,10 +511,8 @@ def is_registered_side_effect(*args, **kwargs): ), msg="Registration should succeed", ) - del os.environ["USE_TORCH"] def test_registration_failed(self): 
- os.environ["USE_TORCH"] = "1" is_registered_return_values = [False for _ in range(100)] current_block = [i for i in range(0, 100)] mock_neuron = MagicMock() @@ -551,11 +546,9 @@ def test_registration_failed(self): msg="Registration should fail", ) self.assertEqual(mock_create_pow.call_count, 3) - del os.environ["USE_TORCH"] def test_registration_stale_then_continue(self): # verify that after a stale solution, the solve will continue without exiting - os.environ["USE_TORCH"] = "1" class ExitEarly(Exception): pass @@ -596,7 +589,6 @@ class ExitEarly(Exception): 1, msg="only tries to submit once, then exits", ) - del os.environ["USE_TORCH"] def test_defaults_to_finney(self): sub = bittensor.subtensor() diff --git a/tests/unit_tests/extrinsics/test_registration.py b/tests/unit_tests/extrinsics/test_registration.py index 861ce6b462..bad8552b17 100644 --- a/tests/unit_tests/extrinsics/test_registration.py +++ b/tests/unit_tests/extrinsics/test_registration.py @@ -50,11 +50,6 @@ def mock_new_wallet(): return mock -@pytest.fixture(autouse=True) -def set_use_torch_env(monkeypatch): - monkeypatch.setenv("USE_TORCH", "1") - - @pytest.mark.parametrize( "wait_for_inclusion,wait_for_finalization,prompt,cuda,dev_id,tpb,num_processes,update_interval,log_verbose,expected", [ diff --git a/tests/unit_tests/extrinsics/test_root.py b/tests/unit_tests/extrinsics/test_root.py index 131ca2303d..4806a022a8 100644 --- a/tests/unit_tests/extrinsics/test_root.py +++ b/tests/unit_tests/extrinsics/test_root.py @@ -21,11 +21,6 @@ def mock_wallet(): return mock -@pytest.fixture(autouse=True) -def set_use_torch_env(monkeypatch): - monkeypatch.setenv("USE_TORCH", "1") - - @pytest.mark.parametrize( "wait_for_inclusion, wait_for_finalization, hotkey_registered, registration_success, prompt, user_response, expected_result", [ diff --git a/tests/unit_tests/utils/test_weight_utils.py b/tests/unit_tests/utils/test_weight_utils.py index f315edcdce..178ecc6415 100644 --- 
a/tests/unit_tests/utils/test_weight_utils.py +++ b/tests/unit_tests/utils/test_weight_utils.py @@ -21,6 +21,8 @@ import bittensor.utils.weight_utils as weight_utils import pytest +from bittensor.utils import torch + def test_convert_weight_and_uids(): uids = np.arange(10) @@ -110,6 +112,64 @@ def test_normalize_with_max_weight(): assert np.abs(y - z).sum() < epsilon +def test_normalize_with_max_weight__legacy_torch_api_compat( + force_legacy_torch_compat_api, +): + weights = torch.rand(1000) + wn = weight_utils.normalize_max_weight(weights, limit=0.01) + assert wn.max() <= 0.01 + + weights = torch.zeros(1000) + wn = weight_utils.normalize_max_weight(weights, limit=0.01) + assert wn.max() <= 0.01 + + weights = torch.rand(1000) + wn = weight_utils.normalize_max_weight(weights, limit=0.02) + assert wn.max() <= 0.02 + + weights = torch.zeros(1000) + wn = weight_utils.normalize_max_weight(weights, limit=0.02) + assert wn.max() <= 0.02 + + weights = torch.rand(1000) + wn = weight_utils.normalize_max_weight(weights, limit=0.03) + assert wn.max() <= 0.03 + + weights = torch.zeros(1000) + wn = weight_utils.normalize_max_weight(weights, limit=0.03) + assert wn.max() <= 0.03 + + # Check for Limit + limit = 0.001 + weights = torch.rand(2000) + w = weights / weights.sum() + wn = weight_utils.normalize_max_weight(weights, limit=limit) + assert (w.max() >= limit and (limit - wn.max()).abs() < 0.001) or ( + w.max() < limit and wn.max() < limit + ) + + # Check for Zeros + limit = 0.01 + weights = torch.zeros(2000) + wn = weight_utils.normalize_max_weight(weights, limit=limit) + assert wn.max() == 1 / 2000 + + # Check for Ordering after normalization + weights = torch.rand(100) + wn = weight_utils.normalize_max_weight(weights, limit=1) + assert torch.equal(wn, weights / weights.sum()) + + # Check for epsilon changes + epsilon = 0.01 + weights, _ = torch.sort(torch.rand(100)) + x = weights / weights.sum() + limit = x[-10] + change = epsilon * limit + y = 
weight_utils.normalize_max_weight(x, limit=limit - change) + z = weight_utils.normalize_max_weight(x, limit=limit + change) + assert (y - z).abs().sum() < epsilon + + +@pytest.mark.parametrize( "test_id, n, uids, weights, expected", [