Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions bittensor/extrinsics/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@
import time
from rich.prompt import Confirm
from typing import List, Union, Optional, Tuple
from bittensor.utils.registration import POWSolution, create_pow, torch, use_torch
from bittensor.utils.registration import (
POWSolution,
create_pow,
torch,
log_no_torch_error,
)


def register_extrinsic(
Expand Down Expand Up @@ -100,7 +105,8 @@ def register_extrinsic(
):
return False

if not use_torch():
if not torch:
log_no_torch_error()
return False

# Attempt rolling registration.
Expand Down Expand Up @@ -380,8 +386,8 @@ def run_faucet_extrinsic(
):
return False, ""

if not use_torch():
torch.error()
if not torch:
log_no_torch_error()
return False, "Requires torch"

# Unlock coldkey
Expand Down
99 changes: 59 additions & 40 deletions bittensor/utils/registration.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import binascii
import functools
import hashlib
import math
import multiprocessing
import os
import random
import time
import typing
from dataclasses import dataclass
from datetime import timedelta
from queue import Empty, Full
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import backoff
import numpy

import bittensor
from Crypto.Hash import keccak
from rich import console as rich_console
Expand All @@ -19,62 +23,77 @@
from .formatting import get_human_readable, millify
from ._register_cuda import solve_cuda

try:
import torch
except ImportError:
torch = None


def use_torch() -> bool:
"""Force the use of torch over numpy for certain operations."""
return True if os.getenv("USE_TORCH") == "1" else False


class Torch:
def __init__(self):
self._transformed = False
def legacy_torch_api_compat(func):
"""Decorator to convert between numpy arrays and torch tensors before/after passing them to the function.
Args:
func (function):
Function to be decorated.
Returns:
decorated (function):
Decorated function.
"""

@staticmethod
def _error():
bittensor.logging.warning(
"This command requires torch. You can install torch for bittensor"
' with `pip install bittensor[torch]` or `pip install ".[torch]"`'
" if installing from source, and then run the command with USE_TORCH=1 {command}"
)
@functools.wraps(func)
def decorated(*args, **kwargs):
if use_torch():
# if argument is a Torch tensor, convert it to numpy
args = [
arg.cpu().numpy() if isinstance(arg, torch.Tensor) else arg
for arg in args
]
kwargs = {
key: value.cpu().numpy() if isinstance(value, torch.Tensor) else value
for key, value in kwargs.items()
}
ret = func(*args, **kwargs)
if use_torch():
# if return value is a numpy array, convert it to Torch tensor
if isinstance(ret, numpy.ndarray):
ret = torch.from_numpy(ret)
return ret

return decorated


@functools.cache
def _get_real_torch():
try:
import torch as _real_torch
except ImportError:
_real_torch = None
return _real_torch

def error(self):
self._error()

def _transform(self):
try:
import torch as real_torch
def log_no_torch_error():
bittensor.btlogging.error(
"This command requires torch. You can install torch for bittensor"
' with `pip install bittensor[torch]` or `pip install ".[torch]"`'
" if installing from source, and then run the command with USE_TORCH=1 {command}"
)

self.__dict__.update(real_torch.__dict__)
self._transformed = True
except ImportError:
self._error()

class LazyLoadedTorch:
def __bool__(self):
return False
return bool(_get_real_torch())

def __getattr__(self, name):
if not self._transformed and use_torch():
self._transform()
if self._transformed:
return getattr(self, name)
if real_torch := _get_real_torch():
return getattr(real_torch, name)
else:
self._error()

def __call__(self, *args, **kwargs):
if not self._transformed and use_torch():
self._transform()
if self._transformed:
return self(*args, **kwargs)
else:
self._error()
log_no_torch_error()
raise ImportError("torch not installed")


if not torch or not use_torch():
torch = Torch()
if typing.TYPE_CHECKING:
import torch
else:
torch = LazyLoadedTorch()


class CUDAException(Exception):
Expand Down
26 changes: 7 additions & 19 deletions bittensor/utils/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@
import bittensor
from numpy.typing import NDArray
from typing import Tuple, List, Union
from bittensor.utils.registration import torch, use_torch
from bittensor.utils.registration import torch, use_torch, legacy_torch_api_compat

U32_MAX = 4294967295
U16_MAX = 65535


@legacy_torch_api_compat
def normalize_max_weight(
x: Union[NDArray[np.float32], "torch.FloatTensor"], limit: float = 0.1
) -> Union[NDArray[np.float32], "torch.FloatTensor"]:
Expand All @@ -43,14 +44,8 @@ def normalize_max_weight(
"""
epsilon = 1e-7 # For numerical stability after normalization

weights = x.clone() if use_torch() else x.copy()
if use_torch():
values, _ = torch.sort(weights)
else:
values = np.sort(weights)

if use_torch() and x.sum() == 0 or len(x) * limit <= 1:
return torch.ones_like(x) / x.size(0)
weights = x.copy()
values = np.sort(weights)

if x.sum() == 0 or x.shape[0] * limit <= 1:
return np.ones_like(x) / x.shape[0]
Expand All @@ -61,18 +56,11 @@ def normalize_max_weight(
return weights / weights.sum()

# Find the cumlative sum and sorted tensor
cumsum = (
torch.cumsum(estimation, 0) if use_torch() else np.cumsum(estimation, 0)
)
cumsum = np.cumsum(estimation, 0)

# Determine the index of cutoff
estimation_sum_data = [
(len(values) - i - 1) * estimation[i] for i in range(len(values))
]
estimation_sum = (
torch.tensor(estimation_sum_data)
if use_torch()
else np.array(estimation_sum_data)
estimation_sum = np.array(
[(len(values) - i - 1) * estimation[i] for i in range(len(values))]
)
n_values = (estimation / (estimation_sum + cumsum + epsilon) < limit).sum()

Expand Down
11 changes: 5 additions & 6 deletions example.env
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# To use Torch functionality in bittensor, you must set the USE_TORCH flag to 1:
USE_TORCH=1

# If set to 0 (or anything else), you will use the numpy functions.
# This is generally what you want unless you have a specific reason for using torch
# such as POW registration or legacy interoperability.
# To use legacy Torch-based of bittensor, you must set USE_TORCH=1
USE_TORCH=0
# If set to 0 (or anything else than 1), it will use current, numpy-based, bittensor interface.
# This is generally what you want unless you want legacy interoperability.
# Please note that the legacy interface is deprecated, and is not tested nearly as much.
2 changes: 0 additions & 2 deletions tests/integration_tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2090,7 +2090,6 @@ def test_register(self, _):
def test_pow_register(self, _):
# Not the best way to do this, but I need to finish these tests, and unittest doesn't make this
# as simple as pytest
os.environ["USE_TORCH"] = "1"
config = self.config
config.command = "subnets"
config.subcommand = "pow_register"
Expand All @@ -2114,7 +2113,6 @@ class MockException(Exception):
mock_create_wallet.assert_called_once()

self.assertEqual(mock_is_stale.call_count, 1)
del os.environ["USE_TORCH"]

def test_stake(self, _):
amount_to_stake: Balance = Balance.from_tao(0.5)
Expand Down
8 changes: 0 additions & 8 deletions tests/integration_tests/test_subtensor_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,6 @@ def test_is_hotkey_registered_not_registered(self):
self.assertFalse(registered, msg="Hotkey should not be registered")

def test_registration_multiprocessed_already_registered(self):
os.environ["USE_TORCH"] = "1"
workblocks_before_is_registered = random.randint(5, 10)
# return False each work block but return True after a random number of blocks
is_registered_return_values = (
Expand Down Expand Up @@ -477,10 +476,8 @@ def test_registration_multiprocessed_already_registered(self):
self.subtensor.is_hotkey_registered.call_count
== workblocks_before_is_registered + 2
)
del os.environ["USE_TORCH"]

def test_registration_partly_failed(self):
os.environ["USE_TORCH"] = "1"
do_pow_register_mock = MagicMock(
side_effect=[(False, "Failed"), (False, "Failed"), (True, None)]
)
Expand Down Expand Up @@ -514,10 +511,8 @@ def is_registered_side_effect(*args, **kwargs):
),
msg="Registration should succeed",
)
del os.environ["USE_TORCH"]

def test_registration_failed(self):
os.environ["USE_TORCH"] = "1"
is_registered_return_values = [False for _ in range(100)]
current_block = [i for i in range(0, 100)]
mock_neuron = MagicMock()
Expand Down Expand Up @@ -551,11 +546,9 @@ def test_registration_failed(self):
msg="Registration should fail",
)
self.assertEqual(mock_create_pow.call_count, 3)
del os.environ["USE_TORCH"]

def test_registration_stale_then_continue(self):
# verify that after a stale solution, the solve will continue without exiting
os.environ["USE_TORCH"] = "1"

class ExitEarly(Exception):
pass
Expand Down Expand Up @@ -596,7 +589,6 @@ class ExitEarly(Exception):
1,
msg="only tries to submit once, then exits",
)
del os.environ["USE_TORCH"]

def test_defaults_to_finney(self):
sub = bittensor.subtensor()
Expand Down
5 changes: 0 additions & 5 deletions tests/unit_tests/extrinsics/test_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,6 @@ def mock_new_wallet():
return mock


@pytest.fixture(autouse=True)
def set_use_torch_env(monkeypatch):
monkeypatch.setenv("USE_TORCH", "1")


@pytest.mark.parametrize(
"wait_for_inclusion,wait_for_finalization,prompt,cuda,dev_id,tpb,num_processes,update_interval,log_verbose,expected",
[
Expand Down
5 changes: 0 additions & 5 deletions tests/unit_tests/extrinsics/test_root.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@ def mock_wallet():
return mock


@pytest.fixture(autouse=True)
def set_use_torch_env(monkeypatch):
monkeypatch.setenv("USE_TORCH", "1")


@pytest.mark.parametrize(
"wait_for_inclusion, wait_for_finalization, hotkey_registered, registration_success, prompt, user_response, expected_result",
[
Expand Down
Loading