diff --git a/bittensor/core/async_subtensor.py b/bittensor/core/async_subtensor.py
index fc59fc121a..928283a26e 100644
--- a/bittensor/core/async_subtensor.py
+++ b/bittensor/core/async_subtensor.py
@@ -846,6 +846,48 @@ async def neurons_lite(
return NeuronInfoLite.list_from_vec_u8(hex_to_bytes(hex_bytes_result))
+ async def get_neuron_for_pubkey_and_subnet(
+ self,
+ hotkey_ss58: str,
+ netuid: int,
+ block_hash: Optional[str] = None,
+ reuse_block: bool = False,
+ ) -> "NeuronInfo":
+ """
+ Retrieves information about a neuron based on its public key (hotkey SS58 address) and the specific subnet UID (netuid). This function provides detailed neuron information for a particular subnet within the Bittensor network.
+
+ Args:
+ hotkey_ss58 (str): The ``SS58`` address of the neuron's hotkey.
+ netuid (int): The unique identifier of the subnet.
+        block_hash (Optional[str]): The hash of the blockchain block at which to perform the query.
+ reuse_block (bool): Whether to reuse the last-used blockchain block hash.
+
+ Returns:
+            bittensor.core.chain_data.neuron_info.NeuronInfo: Detailed information about the neuron if found, a null neuron otherwise.
+
+ This function is crucial for accessing specific neuron data and understanding its status, stake, and other attributes within a particular subnet of the Bittensor ecosystem.
+ """
+ uid = await self.substrate.query(
+ module="SubtensorModule",
+ storage_function="Uids",
+ params=[netuid, hotkey_ss58],
+ block_hash=block_hash,
+ reuse_block_hash=reuse_block,
+ )
+ if uid is None:
+ return NeuronInfo.get_null_neuron()
+
+ params = [netuid, uid]
+ json_body = await self.substrate.rpc_request(
+ method="neuronInfo_getNeuron",
+ params=params,
+ )
+
+ if not (result := json_body.get("result", None)):
+ return NeuronInfo.get_null_neuron()
+
+ return NeuronInfo.from_vec_u8(bytes(result))
+
async def neuron_for_uid(
self, uid: Optional[int], netuid: int, block_hash: Optional[str] = None
) -> NeuronInfo:
diff --git a/bittensor/core/extrinsics/async_registration.py b/bittensor/core/extrinsics/async_registration.py
index a0901a5639..576141e939 100644
--- a/bittensor/core/extrinsics/async_registration.py
+++ b/bittensor/core/extrinsics/async_registration.py
@@ -1,402 +1,29 @@
+"""
+This module provides functionalities for registering a wallet with the subtensor network using Proof-of-Work (PoW).
+
+Extrinsics:
+- register_extrinsic: Registers the wallet to the subnet.
+- run_faucet_extrinsic: Runs a continual PoW to get a faucet of TAO on the test net.
+"""
+
import asyncio
-import binascii
-import functools
-import hashlib
-import io
-import math
-import multiprocessing as mp
-import os
-import random
-import subprocess
import time
-from contextlib import redirect_stdout
-from dataclasses import dataclass
-from datetime import timedelta
-from multiprocessing import Process, Event, Lock, Array, Value, Queue
-from multiprocessing.queues import Queue as Queue_Type
-from queue import Empty, Full
-from typing import Optional, Union, TYPE_CHECKING, Callable, Any
+from typing import Optional, Union, TYPE_CHECKING
-import backoff
-import numpy as np
-from Crypto.Hash import keccak
from bittensor_wallet import Wallet
-from rich.console import Console
-from rich.status import Status
-from substrateinterface.exceptions import SubstrateRequestException
-from bittensor.core.chain_data import NeuronInfo
from bittensor.utils import format_error_message, unlock_key
from bittensor.utils.btlogging import logging
-from bittensor.utils.formatting import millify, get_human_readable
-
-if TYPE_CHECKING:
- from bittensor.core.async_subtensor import AsyncSubtensor
-
-
-# TODO: compair and remove existing code (bittensor.utils.registration)
-
-
-def use_torch() -> bool:
- """Force the use of torch over numpy for certain operations."""
- return True if os.getenv("USE_TORCH") == "1" else False
-
-
-def legacy_torch_api_compat(func: Callable):
- """
- Convert function operating on numpy Input&Output to legacy torch Input&Output API if `use_torch()` is True.
-
- Args:
- func: Function with numpy Input/Output to be decorated.
-
- Returns:
- Decorated function
- """
-
- @functools.wraps(func)
- def decorated(*args, **kwargs):
- if use_torch():
- # if argument is a Torch tensor, convert it to numpy
- args = [
- arg.cpu().numpy() if isinstance(arg, torch.Tensor) else arg
- for arg in args
- ]
- kwargs = {
- key: value.cpu().numpy() if isinstance(value, torch.Tensor) else value
- for key, value in kwargs.items()
- }
- ret = func(*args, **kwargs)
- if use_torch():
- # if return value is a numpy array, convert it to Torch tensor
- if isinstance(ret, np.ndarray):
- ret = torch.from_numpy(ret)
- return ret
-
- return decorated
-
-
-@functools.cache
-def _get_real_torch():
- try:
- import torch as _real_torch
- except ImportError:
- _real_torch = None
- return _real_torch
-
-
-def log_no_torch_error():
- logging.info(
- "This command requires torch. You can install torch with `pip install torch` and run the command again."
- )
-
-
-@dataclass
-class POWSolution:
- """A solution to the registration PoW problem."""
-
- nonce: int
- block_number: int
- difficulty: int
- seal: bytes
-
- async def is_stale(self, subtensor: "AsyncSubtensor") -> bool:
- """
- Returns True if the POW is stale.
-
- This means the block the POW is solved for is within 3 blocks of the current block.
- """
- current_block = await subtensor.substrate.get_block_number(None)
- return self.block_number < current_block - 3
-
-
-@dataclass
-class RegistrationStatistics:
- """Statistics for a registration."""
-
- time_spent_total: float
- rounds_total: int
- time_average: float
- time_spent: float
- hash_rate_perpetual: float
- hash_rate: float
- difficulty: int
- block_number: int
- block_hash: str
-
-
-class RegistrationStatisticsLogger:
- """Logs statistics for a registration."""
-
- console: Console
- status: Optional[Status]
-
- def __init__(
- self, console_: Optional["Console"] = None, output_in_place: bool = True
- ) -> None:
- if console_ is None:
- console_ = Console()
- self.console = console_
-
- if output_in_place:
- self.status = self.console.status("Solving")
- else:
- self.status = None
-
- def start(self) -> None:
- if self.status is not None:
- self.status.start()
-
- def stop(self) -> None:
- if self.status is not None:
- self.status.stop()
-
- @classmethod
- def get_status_message(
- cls, stats: RegistrationStatistics, verbose: bool = False
- ) -> str:
- """Provides a message of the current status of the block solving as a str for a logger or stdout."""
- message = (
- "Solving\n"
- + f"Time Spent (total): [bold white]{timedelta(seconds=stats.time_spent_total)}[/bold white]\n"
- + (
- f"Time Spent This Round: {timedelta(seconds=stats.time_spent)}\n"
- + f"Time Spent Average: {timedelta(seconds=stats.time_average)}\n"
- if verbose
- else ""
- )
- + f"Registration Difficulty: [bold white]{millify(stats.difficulty)}[/bold white]\n"
- + f"Iters (Inst/Perp): [bold white]{get_human_readable(stats.hash_rate, 'H')}/s / "
- + f"{get_human_readable(stats.hash_rate_perpetual, 'H')}/s[/bold white]\n"
- + f"Block Number: [bold white]{stats.block_number}[/bold white]\n"
- + f"Block Hash: [bold white]{stats.block_hash.encode('utf-8')}[/bold white]\n"
- )
- return message
-
- def update(self, stats: RegistrationStatistics, verbose: bool = False) -> None:
- """Passes the current status to the logger."""
- if self.status is not None:
- self.status.update(self.get_status_message(stats, verbose=verbose))
- else:
- self.console.log(self.get_status_message(stats, verbose=verbose))
-
-
-class _SolverBase(Process):
- """
- A process that solves the registration PoW problem.
-
- Args:
- proc_num: The number of the process being created.
- num_proc: The total number of processes running.
- update_interval: The number of nonces to try to solve before checking for a new block.
- finished_queue: The queue to put the process number when a process finishes each update_interval. Used for calculating the average time per update_interval across all processes.
- solution_queue: The queue to put the solution the process has found during the pow solve.
- stop_event: The event to set by the main process when all the solver processes should stop. The solver process will check for the event after each update_interval. The solver process will stop when the event is set. Used to stop the solver processes when a solution is found.
- curr_block: The array containing this process's current block hash. The main process will set the array to the new block hash when a new block is finalized in the network. The solver process will get the new block hash from this array when newBlockEvent is set
- curr_block_num: The value containing this process's current block number. The main process will set the value to the new block number when a new block is finalized in the network. The solver process will get the new block number from this value when new_block_event is set.
- curr_diff: The array containing this process's current difficulty. The main process will set the array to the new difficulty when a new block is finalized in the network. The solver process will get the new difficulty from this array when newBlockEvent is set.
- check_block: The lock to prevent this process from getting the new block data while the main process is updating the data.
- limit: The limit of the pow solve for a valid solution.
-
- Returns:
- new_block_event: The event to set by the main process when a new block is finalized in the network. The solver process will check for the event after each update_interval. The solver process will get the new block hash and difficulty and start solving for a new nonce.
- """
-
- proc_num: int
- num_proc: int
- update_interval: int
- finished_queue: Queue_Type
- solution_queue: Queue_Type
- new_block_event: Event
- stop_event: Event
- hotkey_bytes: bytes
- curr_block: Array
- curr_block_num: Value
- curr_diff: Array
- check_block: Lock
- limit: int
-
- def __init__(
- self,
- proc_num,
- num_proc,
- update_interval,
- finished_queue,
- solution_queue,
- stop_event,
- curr_block,
- curr_block_num,
- curr_diff,
- check_block,
- limit,
- ):
- Process.__init__(self, daemon=True)
- self.proc_num = proc_num
- self.num_proc = num_proc
- self.update_interval = update_interval
- self.finished_queue = finished_queue
- self.solution_queue = solution_queue
- self.new_block_event = Event()
- self.new_block_event.clear()
- self.curr_block = curr_block
- self.curr_block_num = curr_block_num
- self.curr_diff = curr_diff
- self.check_block = check_block
- self.stop_event = stop_event
- self.limit = limit
-
- def run(self):
- raise NotImplementedError("_SolverBase is an abstract class")
-
- @staticmethod
- def create_shared_memory() -> tuple[Array, Value, Array]:
- """Creates shared memory for the solver processes to use."""
- curr_block = Array("h", 32, lock=True) # byte array
- curr_block_num = Value("i", 0, lock=True) # int
- curr_diff = Array("Q", [0, 0], lock=True) # [high, low]
-
- return curr_block, curr_block_num, curr_diff
-
-
-class _Solver(_SolverBase):
- """Performs POW Solution."""
-
- def run(self):
- block_number: int
- block_and_hotkey_hash_bytes: bytes
- block_difficulty: int
- nonce_limit = int(math.pow(2, 64)) - 1
-
- # Start at random nonce
- nonce_start = random.randint(0, nonce_limit)
- nonce_end = nonce_start + self.update_interval
- while not self.stop_event.is_set():
- if self.new_block_event.is_set():
- with self.check_block:
- block_number = self.curr_block_num.value
- block_and_hotkey_hash_bytes = bytes(self.curr_block)
- block_difficulty = _registration_diff_unpack(self.curr_diff)
-
- self.new_block_event.clear()
-
- # Do a block of nonces
- solution = _solve_for_nonce_block(
- nonce_start,
- nonce_end,
- block_and_hotkey_hash_bytes,
- block_difficulty,
- self.limit,
- block_number,
- )
- if solution is not None:
- self.solution_queue.put(solution)
-
- try:
- # Send time
- self.finished_queue.put_nowait(self.proc_num)
- except Full:
- pass
-
- nonce_start = random.randint(0, nonce_limit)
- nonce_start = nonce_start % nonce_limit
- nonce_end = nonce_start + self.update_interval
-
-
-class _CUDASolver(_SolverBase):
- """Performs POW Solution using CUDA."""
-
- dev_id: int
- tpb: int
-
- def __init__(
- self,
- proc_num,
- num_proc,
- update_interval,
- finished_queue,
- solution_queue,
- stop_event,
- curr_block,
- curr_block_num,
- curr_diff,
- check_block,
- limit,
- dev_id: int,
- tpb: int,
- ):
- super().__init__(
- proc_num,
- num_proc,
- update_interval,
- finished_queue,
- solution_queue,
- stop_event,
- curr_block,
- curr_block_num,
- curr_diff,
- check_block,
- limit,
- )
- self.dev_id = dev_id
- self.tpb = tpb
-
- def run(self):
- block_number: int = 0 # dummy value
- block_and_hotkey_hash_bytes: bytes = b"0" * 32 # dummy value
- block_difficulty: int = int(math.pow(2, 64)) - 1 # dummy value
- nonce_limit = int(math.pow(2, 64)) - 1 # U64MAX
-
- # Start at random nonce
- nonce_start = random.randint(0, nonce_limit)
- while not self.stop_event.is_set():
- if self.new_block_event.is_set():
- with self.check_block:
- block_number = self.curr_block_num.value
- block_and_hotkey_hash_bytes = bytes(self.curr_block)
- block_difficulty = _registration_diff_unpack(self.curr_diff)
-
- self.new_block_event.clear()
-
- # Do a block of nonces
- solution = _solve_for_nonce_block_cuda(
- nonce_start,
- self.update_interval,
- block_and_hotkey_hash_bytes,
- block_difficulty,
- self.limit,
- block_number,
- self.dev_id,
- self.tpb,
- )
- if solution is not None:
- self.solution_queue.put(solution)
-
- try:
- # Signal that a nonce_block was finished using queue
- # send our proc_num
- self.finished_queue.put(self.proc_num)
- except Full:
- pass
-
- # increase nonce by number of nonces processed
- nonce_start += self.update_interval * self.tpb
- nonce_start = nonce_start % nonce_limit
-
-
-class LazyLoadedTorch:
- def __bool__(self):
- return bool(_get_real_torch())
-
- def __getattr__(self, name):
- if real_torch := _get_real_torch():
- return getattr(real_torch, name)
- else:
- log_no_torch_error()
- raise ImportError("torch not installed")
-
+from bittensor.utils.registration import log_no_torch_error, create_pow_async
+# For annotation and lazy import purposes
if TYPE_CHECKING:
import torch
+ from bittensor.core.async_subtensor import AsyncSubtensor
+ from bittensor.utils.registration import POWSolution
else:
+ from bittensor.utils.registration.pow import LazyLoadedTorch
+
torch = LazyLoadedTorch()
@@ -408,21 +35,6 @@ class MaxAttemptsException(Exception):
"""Raised when the POW Solver has reached the max number of attempts."""
-async def is_hotkey_registered(
- subtensor: "AsyncSubtensor", netuid: int, hotkey_ss58: str
-) -> bool:
- """Checks to see if the hotkey is registered on a given netuid"""
- _result = await subtensor.substrate.query(
- module="SubtensorModule",
- storage_function="Uids",
- params=[netuid, hotkey_ss58],
- )
- if _result is not None:
- return True
- else:
- return False
-
-
async def register_extrinsic(
subtensor: "AsyncSubtensor",
wallet: "Wallet",
@@ -459,24 +71,6 @@ async def register_extrinsic(
`True` if extrinsic was finalized or included in the block. If we did not wait for finalization/inclusion, the response is `True`.
"""
- async def get_neuron_for_pubkey_and_subnet():
- uid = await subtensor.substrate.query(
- "SubtensorModule", "Uids", [netuid, wallet.hotkey.ss58_address]
- )
- if uid is None:
- return NeuronInfo.get_null_neuron()
-
- params = [netuid, uid]
- json_body = await subtensor.substrate.rpc_request(
- method="neuronInfo_getNeuron",
- params=params,
- )
-
- if not (result := json_body.get("result", None)):
- return NeuronInfo.get_null_neuron()
-
- return NeuronInfo.from_vec_u8(bytes(result))
-
logging.debug("Checking subnet status")
if not await subtensor.subnet_exists(netuid):
logging.error(
@@ -487,7 +81,10 @@ async def get_neuron_for_pubkey_and_subnet():
logging.info(
f":satellite: Checking Account on subnet {netuid} ..."
)
- neuron = await get_neuron_for_pubkey_and_subnet()
+ neuron = await subtensor.get_neuron_for_pubkey_and_subnet(
+ hotkey_ss58=wallet.hotkey.ss58_address,
+ netuid=netuid,
+ )
if not neuron.is_null:
logging.debug(
f"Wallet {wallet} is already registered on subnet {neuron.netuid} with uid{neuron.uid}."
@@ -500,7 +97,7 @@ async def get_neuron_for_pubkey_and_subnet():
# Attempt rolling registration.
attempts = 1
- pow_result: Optional[POWSolution]
+ pow_result: Optional["POWSolution"]
while True:
logging.info(
f":satellite: Registering... ({attempts}/{max_allowed_attempts})"
@@ -509,7 +106,7 @@ async def get_neuron_for_pubkey_and_subnet():
if cuda:
if not torch.cuda.is_available():
return False
- pow_result = await create_pow(
+ pow_result = await create_pow_async(
subtensor,
wallet,
netuid,
@@ -522,7 +119,7 @@ async def get_neuron_for_pubkey_and_subnet():
log_verbose=log_verbose,
)
else:
- pow_result = await create_pow(
+ pow_result = await create_pow_async(
subtensor,
wallet,
netuid,
@@ -536,8 +133,8 @@ async def get_neuron_for_pubkey_and_subnet():
# pow failed
if not pow_result:
# might be registered already on this subnet
- is_registered = await is_hotkey_registered(
- subtensor, netuid=netuid, hotkey_ss58=wallet.hotkey.ss58_address
+ is_registered = await subtensor.is_hotkey_registered(
+ netuid=netuid, hotkey_ss58=wallet.hotkey.ss58_address
)
if is_registered:
logging.error(
@@ -549,7 +146,7 @@ async def get_neuron_for_pubkey_and_subnet():
else:
logging.info(":satellite: Submitting POW...")
# check if pow result is still valid
- while not await pow_result.is_stale(subtensor=subtensor):
+ while not await pow_result.is_stale_async(subtensor=subtensor):
call = await subtensor.substrate.compose_call(
call_module="SubtensorModule",
call_function="register",
@@ -597,10 +194,8 @@ async def get_neuron_for_pubkey_and_subnet():
# Successful registration, final check for neuron and pubkey
if success:
logging.info(":satellite: Checking Registration status...")
- is_registered = await is_hotkey_registered(
- subtensor,
- netuid=netuid,
- hotkey_ss58=wallet.hotkey.ss58_address,
+ is_registered = await subtensor.is_hotkey_registered(
+ netuid=netuid, hotkey_ss58=wallet.hotkey.ss58_address
)
if is_registered:
logging.success(
@@ -633,7 +228,7 @@ async def get_neuron_for_pubkey_and_subnet():
async def run_faucet_extrinsic(
subtensor: "AsyncSubtensor",
- wallet: Wallet,
+ wallet: "Wallet",
wait_for_inclusion: bool = False,
wait_for_finalization: bool = True,
max_allowed_attempts: int = 3,
@@ -684,12 +279,14 @@ async def run_faucet_extrinsic(
while True:
try:
pow_result = None
- while pow_result is None or await pow_result.is_stale(subtensor=subtensor):
+ while pow_result is None or await pow_result.is_stale_async(
+ subtensor=subtensor
+ ):
# Solve latest POW.
if cuda:
if not torch.cuda.is_available():
return False, "CUDA is not available."
- pow_result: Optional[POWSolution] = await create_pow(
+ pow_result = await create_pow_async(
subtensor,
wallet,
-1,
@@ -702,7 +299,7 @@ async def run_faucet_extrinsic(
log_verbose=log_verbose,
)
else:
- pow_result: Optional[POWSolution] = await create_pow(
+ pow_result = await create_pow_async(
subtensor,
wallet,
-1,
@@ -734,7 +331,7 @@ async def run_faucet_extrinsic(
await response.process_events()
if not await response.is_success:
logging.error(
- f":cross_mark: Failed: {format_error_message(await response.error_message, subtensor.substrate)}"
+ f":cross_mark: Failed: {format_error_message(error_message=await response.error_message, substrate=subtensor.substrate)}"
)
if attempts == max_allowed_attempts:
raise MaxAttemptsException
@@ -766,794 +363,3 @@ async def run_faucet_extrinsic(
except MaxAttemptsException:
return False, f"Max attempts reached: {max_allowed_attempts}"
-
-
-async def _check_for_newest_block_and_update(
- subtensor: "AsyncSubtensor",
- netuid: int,
- old_block_number: int,
- hotkey_bytes: bytes,
- curr_diff: Array,
- curr_block: Array,
- curr_block_num: Value,
- update_curr_block: "Callable",
- check_block: Lock,
- solvers: list[_Solver],
- curr_stats: "RegistrationStatistics",
-) -> int:
- """
- Checks for a new block and updates the current block information if a new block is found.
-
- Args:
- subtensor: The subtensor object to use for getting the current block.
- netuid: The netuid to use for retrieving the difficulty.
- old_block_number: The old block number to check against.
- hotkey_bytes: The bytes of the hotkey's pubkey.
- curr_diff: The current difficulty as a multiprocessing array.
- curr_block: Where the current block is stored as a multiprocessing array.
- curr_block_num: Where the current block number is stored as a multiprocessing value.
- update_curr_block: A function that updates the current block.
- check_block: A mp lock that is used to check for a new block.
- solvers: A list of solvers to update the current block for.
- curr_stats: The current registration statistics to update.
-
- Returns:
- The current block number.
- """
- block_number = await subtensor.substrate.get_block_number(None)
- if block_number != old_block_number:
- old_block_number = block_number
- # update block information
- block_number, difficulty, block_hash = await _get_block_with_retry(
- subtensor=subtensor, netuid=netuid
- )
- block_bytes = bytes.fromhex(block_hash[2:])
-
- update_curr_block(
- curr_diff,
- curr_block,
- curr_block_num,
- block_number,
- block_bytes,
- difficulty,
- hotkey_bytes,
- check_block,
- )
- # Set new block events for each solver
-
- for worker in solvers:
- worker.new_block_event.set()
-
- # update stats
- curr_stats.block_number = block_number
- curr_stats.block_hash = block_hash
- curr_stats.difficulty = difficulty
-
- return old_block_number
-
-
-async def _block_solver(
- subtensor: "AsyncSubtensor",
- wallet: Wallet,
- num_processes: int,
- netuid: int,
- dev_id: list[int],
- tpb: int,
- update_interval: int,
- curr_block,
- curr_block_num,
- curr_diff,
- n_samples,
- alpha_,
- output_in_place,
- log_verbose,
- cuda: bool,
-):
- """Shared code used by the Solvers to solve the POW solution."""
- limit = int(math.pow(2, 256)) - 1
-
- # Establish communication queues
- # See the _Solver class for more information on the queues.
- stop_event = Event()
- stop_event.clear()
-
- solution_queue = Queue()
- finished_queues = [Queue() for _ in range(num_processes)]
- check_block = Lock()
-
- hotkey_bytes = (
- wallet.coldkeypub.public_key if netuid == -1 else wallet.hotkey.public_key
- )
-
- if cuda:
- # Create a worker per CUDA device
- num_processes = len(dev_id)
- solvers = [
- _CUDASolver(
- i,
- num_processes,
- update_interval,
- finished_queues[i],
- solution_queue,
- stop_event,
- curr_block,
- curr_block_num,
- curr_diff,
- check_block,
- limit,
- dev_id[i],
- tpb,
- )
- for i in range(num_processes)
- ]
- else:
- # Start consumers
- solvers = [
- _Solver(
- i,
- num_processes,
- update_interval,
- finished_queues[i],
- solution_queue,
- stop_event,
- curr_block,
- curr_block_num,
- curr_diff,
- check_block,
- limit,
- )
- for i in range(num_processes)
- ]
-
- # Get first block
- block_number, difficulty, block_hash = await _get_block_with_retry(
- subtensor=subtensor, netuid=netuid
- )
-
- block_bytes = bytes.fromhex(block_hash[2:])
- old_block_number = block_number
- # Set to current block
- _update_curr_block(
- curr_diff,
- curr_block,
- curr_block_num,
- block_number,
- block_bytes,
- difficulty,
- hotkey_bytes,
- check_block,
- )
-
- # Set new block events for each solver to start at the initial block
- for worker in solvers:
- worker.new_block_event.set()
-
- for worker in solvers:
- worker.start() # start the solver processes
-
- start_time = time.time() # time that the registration started
- time_last = start_time # time that the last work blocks completed
-
- curr_stats = RegistrationStatistics(
- time_spent_total=0.0,
- time_average=0.0,
- rounds_total=0,
- time_spent=0.0,
- hash_rate_perpetual=0.0,
- hash_rate=0.0,
- difficulty=difficulty,
- block_number=block_number,
- block_hash=block_hash,
- )
-
- start_time_perpetual = time.time()
-
- logger = RegistrationStatisticsLogger(output_in_place=output_in_place)
- logger.start()
-
- solution = None
-
- hash_rates = [0] * n_samples # The last n true hash_rates
- weights = [alpha_**i for i in range(n_samples)] # weights decay by alpha
-
- timeout = 0.15 if cuda else 0.15
- while netuid == -1 or not await is_hotkey_registered(
- subtensor, netuid, wallet.hotkey.ss58_address
- ):
- # Wait until a solver finds a solution
- try:
- solution = solution_queue.get(block=True, timeout=timeout)
- if solution is not None:
- break
- except Empty:
- # No solution found, try again
- pass
-
- # check for new block
- old_block_number = await _check_for_newest_block_and_update(
- subtensor=subtensor,
- netuid=netuid,
- hotkey_bytes=hotkey_bytes,
- old_block_number=old_block_number,
- curr_diff=curr_diff,
- curr_block=curr_block,
- curr_block_num=curr_block_num,
- curr_stats=curr_stats,
- update_curr_block=_update_curr_block,
- check_block=check_block,
- solvers=solvers,
- )
-
- num_time = 0
- for finished_queue in finished_queues:
- try:
- finished_queue.get(timeout=0.1)
- num_time += 1
-
- except Empty:
- continue
-
- time_now = time.time() # get current time
- time_since_last = time_now - time_last # get time since last work block(s)
- if num_time > 0 and time_since_last > 0.0:
- # create EWMA of the hash_rate to make measure more robust
-
- if cuda:
- hash_rate_ = (num_time * tpb * update_interval) / time_since_last
- else:
- hash_rate_ = (num_time * update_interval) / time_since_last
- hash_rates.append(hash_rate_)
- hash_rates.pop(0) # remove the 0th data point
- curr_stats.hash_rate = sum(
- [hash_rates[i] * weights[i] for i in range(n_samples)]
- ) / (sum(weights))
-
- # update time last to now
- time_last = time_now
-
- curr_stats.time_average = (
- curr_stats.time_average * curr_stats.rounds_total
- + curr_stats.time_spent
- ) / (curr_stats.rounds_total + num_time)
- curr_stats.rounds_total += num_time
-
- # Update stats
- curr_stats.time_spent = time_since_last
- new_time_spent_total = time_now - start_time_perpetual
- if cuda:
- curr_stats.hash_rate_perpetual = (
- curr_stats.rounds_total * (tpb * update_interval)
- ) / new_time_spent_total
- else:
- curr_stats.hash_rate_perpetual = (
- curr_stats.rounds_total * update_interval
- ) / new_time_spent_total
- curr_stats.time_spent_total = new_time_spent_total
-
- # Update the logger
- logger.update(curr_stats, verbose=log_verbose)
-
- # exited while, solution contains the nonce or wallet is registered
- stop_event.set() # stop all other processes
- logger.stop()
-
- # terminate and wait for all solvers to exit
- _terminate_workers_and_wait_for_exit(solvers)
-
- return solution
-
-
-async def _solve_for_difficulty_fast_cuda(
- subtensor: "AsyncSubtensor",
- wallet: Wallet,
- netuid: int,
- output_in_place: bool = True,
- update_interval: int = 50_000,
- tpb: int = 512,
- dev_id: Union[list[int], int] = 0,
- n_samples: int = 10,
- alpha_: float = 0.80,
- log_verbose: bool = False,
-) -> Optional["POWSolution"]:
- """
- Solves the registration fast using CUDA
-
- Args:
- subtensor: The subtensor node to grab blocks
- wallet: The wallet to register
- netuid: The netuid of the subnet to register to.
- output_in_place: If true, prints the output in place, otherwise prints to new lines
- update_interval: The number of nonces to try before checking for more blocks
- tpb: The number of threads per block. CUDA param that should match the GPU capability
- dev_id: The CUDA device IDs to execute the registration on, either a single device or a list of devices
- n_samples: The number of samples of the hash_rate to keep for the EWMA
- alpha_: The alpha for the EWMA for the hash_rate calculation
- log_verbose: If true, prints more verbose logging of the registration metrics.
-
- Note:
- The hash rate is calculated as an exponentially weighted moving average in order to make the measure more robust.
- """
- if isinstance(dev_id, int):
- dev_id = [dev_id]
- elif dev_id is None:
- dev_id = [0]
-
- if update_interval is None:
- update_interval = 50_000
-
- if not torch.cuda.is_available():
- raise Exception("CUDA not available")
-
- # Set mp start to use spawn so CUDA doesn't complain
- with _UsingSpawnStartMethod(force=True):
- curr_block, curr_block_num, curr_diff = _CUDASolver.create_shared_memory()
-
- solution = await _block_solver(
- subtensor=subtensor,
- wallet=wallet,
- num_processes=None,
- netuid=netuid,
- dev_id=dev_id,
- tpb=tpb,
- update_interval=update_interval,
- curr_block=curr_block,
- curr_block_num=curr_block_num,
- curr_diff=curr_diff,
- n_samples=n_samples,
- alpha_=alpha_,
- output_in_place=output_in_place,
- log_verbose=log_verbose,
- cuda=True,
- )
-
- return solution
-
-
-async def _solve_for_difficulty_fast(
- subtensor,
- wallet: Wallet,
- netuid: int,
- output_in_place: bool = True,
- num_processes: Optional[int] = None,
- update_interval: Optional[int] = None,
- n_samples: int = 10,
- alpha_: float = 0.80,
- log_verbose: bool = False,
-) -> Optional[POWSolution]:
- """
- Solves the POW for registration using multiprocessing.
-
- Args:
- subtensor: Subtensor to connect to for block information and to submit.
- wallet: wallet to use for registration.
- netuid: The netuid of the subnet to register to.
- output_in_place: If true, prints the status in place. Otherwise, prints the status on a new line.
- num_processes: Number of processes to use.
- update_interval: Number of nonces to solve before updating block information.
- n_samples: The number of samples of the hash_rate to keep for the EWMA
- alpha_: The alpha for the EWMA for the hash_rate calculation
- log_verbose: If true, prints more verbose logging of the registration metrics.
-
- Notes:
- The hash rate is calculated as an exponentially weighted moving average in order to make the measure more robust.
- We can also modify the update interval to do smaller blocks of work, while still updating the block information after a different number of nonces, to increase the transparency of the process while still keeping the speed.
- """
- if not num_processes:
- # get the number of allowed processes for this process
- num_processes = min(1, get_cpu_count())
-
- if update_interval is None:
- update_interval = 50_000
-
- curr_block, curr_block_num, curr_diff = _Solver.create_shared_memory()
-
- solution = await _block_solver(
- subtensor=subtensor,
- wallet=wallet,
- num_processes=num_processes,
- netuid=netuid,
- dev_id=None,
- tpb=None,
- update_interval=update_interval,
- curr_block=curr_block,
- curr_block_num=curr_block_num,
- curr_diff=curr_diff,
- n_samples=n_samples,
- alpha_=alpha_,
- output_in_place=output_in_place,
- log_verbose=log_verbose,
- cuda=False,
- )
-
- return solution
-
-
-def _terminate_workers_and_wait_for_exit(
- workers: list[Union[Process, Queue_Type]],
-) -> None:
- for worker in workers:
- if isinstance(worker, Queue_Type):
- worker.join_thread()
- else:
- try:
- worker.join(3.0)
- except subprocess.TimeoutExpired:
- worker.terminate()
- try:
- worker.close()
- except ValueError:
- worker.terminate()
-
-
-# TODO verify this works with async
-@backoff.on_exception(backoff.constant, Exception, interval=1, max_tries=3)
-async def _get_block_with_retry(
- subtensor: "AsyncSubtensor", netuid: int
-) -> tuple[int, int, bytes]:
- """
- Gets the current block number, difficulty, and block hash from the substrate node.
-
- Args:
- subtensor: The subtensor object to use to get the block number, difficulty, and block hash.
- netuid: The netuid of the network to get the block number, difficulty, and block hash from.
-
- Returns:
- The current block number, difficulty of the subnet, block hash
-
- Raises:
- Exception: If the block hash is None.
- ValueError: If the difficulty is None.
- """
- block_number = await subtensor.substrate.get_block_number(None)
- block_hash = await subtensor.substrate.get_block_hash(
- block_number
- ) # TODO check if I need to do all this
- try:
- difficulty = (
- 1_000_000
- if netuid == -1
- else int(
- await subtensor.get_hyperparameter(
- param_name="Difficulty", netuid=netuid, block_hash=block_hash
- )
- )
- )
- except TypeError:
- raise ValueError("Chain error. Difficulty is None")
- except SubstrateRequestException:
- raise Exception(
- "Network error. Could not connect to substrate to get block hash"
- )
- return block_number, difficulty, block_hash
-
-
-def _registration_diff_unpack(packed_diff: Array) -> int:
- """Unpacks the packed two 32-bit integers into one 64-bit integer. Little endian."""
- return int(packed_diff[0] << 32 | packed_diff[1])
-
-
-def _registration_diff_pack(diff: int, packed_diff: Array):
- """Packs the difficulty into two 32-bit integers. Little endian."""
- packed_diff[0] = diff >> 32
- packed_diff[1] = diff & 0xFFFFFFFF # low 32 bits
-
-
-class _UsingSpawnStartMethod:
- def __init__(self, force: bool = False):
- self._old_start_method = None
- self._force = force
-
- def __enter__(self):
- self._old_start_method = mp.get_start_method(allow_none=True)
- if self._old_start_method is None:
- self._old_start_method = "spawn" # default to spawn
-
- mp.set_start_method("spawn", force=self._force)
-
- def __exit__(self, *args):
- # restore the old start method
- mp.set_start_method(self._old_start_method, force=True)
-
-
-async def create_pow(
- subtensor: "AsyncSubtensor",
- wallet: Wallet,
- netuid: int,
- output_in_place: bool = True,
- cuda: bool = False,
- dev_id: Union[list[int], int] = 0,
- tpb: int = 256,
- num_processes: int = None,
- update_interval: int = None,
- log_verbose: bool = False,
-) -> Optional[dict[str, Any]]:
- """
- Creates a proof of work for the given subtensor and wallet.
-
- Args:
- subtensor: The subtensor to create a proof of work for.
- wallet: The wallet to create a proof of work for.
- netuid: The netuid for the subnet to create a proof of work for.
- output_in_place: If true, prints the progress of the proof of work to the console in-place. Meaning the progress is printed on the same lines.
- cuda: If true, uses CUDA to solve the proof of work.
- dev_id: The CUDA device id(s) to use. If cuda is true and dev_id is a list, then multiple CUDA devices will be used to solve the proof of work.
- tpb: The number of threads per block to use when solving the proof of work. Should be a multiple of 32.
- num_processes: The number of processes to use when solving the proof of work. If None, then the number of processes is equal to the number of CPU cores.
- update_interval: The number of nonces to run before checking for a new block.
- log_verbose: If true, prints the progress of the proof of work more verbosely.
-
- Returns:
- The proof of work solution or None if the wallet is already registered or there is a different error.
-
- Raises:
- ValueError: If the subnet does not exist.
- """
- if netuid != -1:
- if not await subtensor.subnet_exists(netuid=netuid):
- raise ValueError(f"Subnet {netuid} does not exist")
-
- if cuda:
- solution: Optional[POWSolution] = await _solve_for_difficulty_fast_cuda(
- subtensor,
- wallet,
- netuid=netuid,
- output_in_place=output_in_place,
- dev_id=dev_id,
- tpb=tpb,
- update_interval=update_interval,
- log_verbose=log_verbose,
- )
- else:
- solution: Optional[POWSolution] = await _solve_for_difficulty_fast(
- subtensor,
- wallet,
- netuid=netuid,
- output_in_place=output_in_place,
- num_processes=num_processes,
- update_interval=update_interval,
- log_verbose=log_verbose,
- )
-
- return solution
-
-
-def _solve_for_nonce_block_cuda(
- nonce_start: int,
- update_interval: int,
- block_and_hotkey_hash_bytes: bytes,
- difficulty: int,
- limit: int,
- block_number: int,
- dev_id: int,
- tpb: int,
-) -> Optional[POWSolution]:
- """
- Tries to solve the POW on a CUDA device for a block of nonces (nonce_start, nonce_start + update_interval * tpb
- """
- solution, seal = solve_cuda(
- nonce_start,
- update_interval,
- tpb,
- block_and_hotkey_hash_bytes,
- difficulty,
- limit,
- dev_id,
- )
-
- if solution != -1:
- # Check if solution is valid (i.e. not -1)
- return POWSolution(solution, block_number, difficulty, seal)
-
- return None
-
-
-def _solve_for_nonce_block(
- nonce_start: int,
- nonce_end: int,
- block_and_hotkey_hash_bytes: bytes,
- difficulty: int,
- limit: int,
- block_number: int,
-) -> Optional[POWSolution]:
- """
- Tries to solve the POW for a block of nonces (nonce_start, nonce_end)
- """
- for nonce in range(nonce_start, nonce_end):
- # Create seal.
- seal = _create_seal_hash(block_and_hotkey_hash_bytes, nonce)
-
- # Check if seal meets difficulty
- if _seal_meets_difficulty(seal, difficulty, limit):
- # Found a solution, save it.
- return POWSolution(nonce, block_number, difficulty, seal)
-
- return None
-
-
-class CUDAException(Exception):
- """An exception raised when an error occurs in the CUDA environment."""
-
-
-def _hex_bytes_to_u8_list(hex_bytes: bytes):
- hex_chunks = [int(hex_bytes[i : i + 2], 16) for i in range(0, len(hex_bytes), 2)]
- return hex_chunks
-
-
-def _create_seal_hash(block_and_hotkey_hash_bytes: bytes, nonce: int) -> bytes:
- """
- Create a cryptographic seal hash from the given block and hotkey hash bytes and nonce.
-
- This function generates a seal hash by combining the given block and hotkey hash bytes with a nonce.
- It first converts the nonce to a byte representation, then concatenates it with the first 64 hex characters of the block and hotkey hash bytes. The result is then hashed using SHA-256 followed by the Keccak-256 algorithm to produce the final seal hash.
-
- Args:
- block_and_hotkey_hash_bytes (bytes): The combined hash bytes of the block and hotkey.
- nonce (int): The nonce value used for hashing.
-
- Returns:
- The resulting seal hash.
- """
- nonce_bytes = binascii.hexlify(nonce.to_bytes(8, "little"))
- pre_seal = nonce_bytes + binascii.hexlify(block_and_hotkey_hash_bytes)[:64]
- seal_sh256 = hashlib.sha256(bytearray(_hex_bytes_to_u8_list(pre_seal))).digest()
- kec = keccak.new(digest_bits=256)
- seal = kec.update(seal_sh256).digest()
- return seal
-
-
-def _seal_meets_difficulty(seal: bytes, difficulty: int, limit: int) -> bool:
- """Determines if a seal meets the specified difficulty"""
- seal_number = int.from_bytes(seal, "big")
- product = seal_number * difficulty
- return product < limit
-
-
-def _hash_block_with_hotkey(block_bytes: bytes, hotkey_bytes: bytes) -> bytes:
- """Hashes the block with the hotkey using Keccak-256 to get 32 bytes"""
- kec = keccak.new(digest_bits=256)
- kec = kec.update(bytearray(block_bytes + hotkey_bytes))
- block_and_hotkey_hash_bytes = kec.digest()
- return block_and_hotkey_hash_bytes
-
-
-def _update_curr_block(
- curr_diff: Array,
- curr_block: Array,
- curr_block_num: Value,
- block_number: int,
- block_bytes: bytes,
- diff: int,
- hotkey_bytes: bytes,
- lock: Lock,
-):
- """
- Update the current block data with the provided block information and difficulty.
-
- This function updates the current block and its difficulty in a thread-safe manner. It sets the current block
- number, hashes the block with the hotkey, updates the current block bytes, and packs the difficulty.
-
- curr_diff: Shared array to store the current difficulty.
- curr_block: Shared array to store the current block data.
- curr_block_num: Shared value to store the current block number.
- block_number: The block number to set as the current block number.
- block_bytes: The block data bytes to be hashed with the hotkey.
- diff: The difficulty value to be packed into the current difficulty array.
- hotkey_bytes: The hotkey bytes used for hashing the block.
- lock: A lock to ensure thread-safe updates.
- """
- with lock:
- curr_block_num.value = block_number
- # Hash the block with the hotkey
- block_and_hotkey_hash_bytes = _hash_block_with_hotkey(block_bytes, hotkey_bytes)
- for i in range(32):
- curr_block[i] = block_and_hotkey_hash_bytes[i]
- _registration_diff_pack(diff, curr_diff)
-
-
-def get_cpu_count() -> int:
- try:
- return len(os.sched_getaffinity(0))
- except AttributeError:
- # macOS does not have sched_getaffinity
- return os.cpu_count()
-
-
-@dataclass
-class RegistrationStatistics:
- """Statistics for a registration."""
-
- time_spent_total: float
- rounds_total: int
- time_average: float
- time_spent: float
- hash_rate_perpetual: float
- hash_rate: float
- difficulty: int
- block_number: int
- block_hash: bytes
-
-
-def solve_cuda(
- nonce_start: np.int64,
- update_interval: np.int64,
- tpb: int,
- block_and_hotkey_hash_bytes: bytes,
- difficulty: int,
- limit: int,
- dev_id: int = 0,
-) -> tuple[np.int64, bytes]:
- """
- Solves the PoW problem using CUDA.
-
- nonce_start: Starting nonce.
- update_interval: Number of nonces to solve before updating block information.
- tpb: Threads per block.
- block_and_hotkey_hash_bytes: Keccak(Bytes of the block hash + bytes of the hotkey) 64 bytes.
- difficulty: Difficulty of the PoW problem.
- limit: Upper limit of the nonce.
- dev_id: The CUDA device ID
-
- :return: (nonce, seal) corresponding to the solution. Returns -1 for nonce if no solution is found.
- """
-
- try:
- import cubit
- except ImportError:
- raise ImportError("Please install cubit")
-
- upper = int(limit // difficulty)
-
- upper_bytes = upper.to_bytes(32, byteorder="little", signed=False)
-
- # Call cython function
- # int blockSize, uint64 nonce_start, uint64 update_interval, const unsigned char[:] limit,
- # const unsigned char[:] block_bytes, int dev_id
- block_and_hotkey_hash_hex = binascii.hexlify(block_and_hotkey_hash_bytes)[:64]
-
- solution = cubit.solve_cuda(
- tpb,
- nonce_start,
- update_interval,
- upper_bytes,
- block_and_hotkey_hash_hex,
- dev_id,
- ) # 0 is first GPU
- seal = None
- if solution != -1:
- seal = _create_seal_hash(block_and_hotkey_hash_hex, solution)
- if _seal_meets_difficulty(seal, difficulty, limit):
- return solution, seal
- else:
- return -1, b"\x00" * 32
-
- return solution, seal
-
-
-def reset_cuda():
- """
- Resets the CUDA environment.
- """
- try:
- import cubit
- except ImportError:
- raise ImportError("Please install cubit")
-
- cubit.reset_cuda()
-
-
-def log_cuda_errors() -> str:
- """
- Logs any CUDA errors.
- """
- try:
- import cubit
- except ImportError:
- raise ImportError("Please install cubit")
-
- f = io.StringIO()
- with redirect_stdout(f):
- cubit.log_cuda_errors()
-
- s = f.getvalue()
-
- return s
diff --git a/bittensor/core/extrinsics/registration.py b/bittensor/core/extrinsics/registration.py
index 57bf9e7a56..7af18c24c7 100644
--- a/bittensor/core/extrinsics/registration.py
+++ b/bittensor/core/extrinsics/registration.py
@@ -1,19 +1,10 @@
-# The MIT License (MIT)
-# Copyright © 2024 Opentensor Foundation
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
-# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
-# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
-# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
-# the Software.
-#
-# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
-# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
-# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
-# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
-# DEALINGS IN THE SOFTWARE.
+"""
+This module provides functionalities for registering a wallet with the subtensor network using Proof-of-Work (PoW).
+
+Extrinsics:
+- register_extrinsic: Registers the wallet to the subnet.
+- burned_register_extrinsic: Registers the wallet to chain by recycling TAO.
+"""
import time
from typing import Union, Optional, TYPE_CHECKING
@@ -23,16 +14,21 @@
from bittensor.utils.btlogging import logging
from bittensor.utils.networking import ensure_connected
from bittensor.utils.registration import (
- POWSolution,
create_pow,
torch,
log_no_torch_error,
)
-# For annotation purposes
+# For annotation and lazy import purposes
if TYPE_CHECKING:
- from bittensor.core.subtensor import Subtensor
+ import torch
from bittensor_wallet import Wallet
+ from bittensor.core.subtensor import Subtensor
+ from bittensor.utils.registration import POWSolution
+else:
+ from bittensor.utils.registration.pow import LazyLoadedTorch
+
+ torch = LazyLoadedTorch()
@ensure_connected
@@ -47,6 +43,7 @@ def _do_pow_register(
"""Sends a (POW) register extrinsic to the chain.
Args:
+ self (bittensor.core.subtensor.Subtensor): The subtensor to send the extrinsic to.
netuid (int): The subnet to register on.
wallet (bittensor.wallet): The wallet to register.
pow_result (POWSolution): The PoW result to register.
@@ -164,7 +161,7 @@ def register_extrinsic(
if cuda:
if not torch.cuda.is_available():
return False
- pow_result: Optional[POWSolution] = create_pow(
+ pow_result: Optional["POWSolution"] = create_pow(
subtensor,
wallet,
netuid,
@@ -177,7 +174,7 @@ def register_extrinsic(
log_verbose=log_verbose,
)
else:
- pow_result: Optional[POWSolution] = create_pow(
+ pow_result: Optional["POWSolution"] = create_pow(
subtensor,
wallet,
netuid,
diff --git a/bittensor/core/subtensor.py b/bittensor/core/subtensor.py
index efaff369a4..6416d9971d 100644
--- a/bittensor/core/subtensor.py
+++ b/bittensor/core/subtensor.py
@@ -16,8 +16,8 @@
# DEALINGS IN THE SOFTWARE.
"""
-The ``bittensor.core.subtensor.Subtensor`` module in Bittensor serves as a crucial interface for interacting with the Bittensor
-blockchain, facilitating a range of operations essential for the decentralized machine learning network.
+The ``bittensor.core.subtensor.Subtensor`` module in Bittensor serves as a crucial interface for interacting with the
+Bittensor blockchain, facilitating a range of operations essential for the decentralized machine learning network.
"""
import argparse
diff --git a/bittensor/utils/registration.py b/bittensor/utils/registration.py
index f190f668da..e69de29bb2 100644
--- a/bittensor/utils/registration.py
+++ b/bittensor/utils/registration.py
@@ -1,1120 +0,0 @@
-# The MIT License (MIT)
-# Copyright © 2024 Opentensor Foundation
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
-# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
-# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
-# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
-# the Software.
-#
-# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
-# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
-# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
-# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
-# DEALINGS IN THE SOFTWARE.
-
-import binascii
-import dataclasses
-import functools
-import hashlib
-import math
-import multiprocessing
-import os
-import random
-import subprocess
-import time
-from datetime import timedelta
-from multiprocessing.queues import Queue as QueueType
-from queue import Empty, Full
-from typing import Any, Callable, Optional, Union, TYPE_CHECKING
-
-import numpy
-from Crypto.Hash import keccak
-from retry import retry
-from rich import console as rich_console, status as rich_status
-from rich.console import Console
-
-from bittensor.utils.btlogging import logging
-from bittensor.utils.formatting import get_human_readable, millify
-from bittensor.utils.register_cuda import solve_cuda
-
-
-def use_torch() -> bool:
- """Force the use of torch over numpy for certain operations."""
- return True if os.getenv("USE_TORCH") == "1" else False
-
-
-def legacy_torch_api_compat(func):
- """
- Convert function operating on numpy Input&Output to legacy torch Input&Output API if `use_torch()` is True.
-
- Args:
- func (function): Function with numpy Input/Output to be decorated.
-
- Returns:
- decorated (function): Decorated function.
- """
-
- @functools.wraps(func)
- def decorated(*args, **kwargs):
- if use_torch():
- # if argument is a Torch tensor, convert it to numpy
- args = [
- arg.cpu().numpy() if isinstance(arg, torch.Tensor) else arg
- for arg in args
- ]
- kwargs = {
- key: value.cpu().numpy() if isinstance(value, torch.Tensor) else value
- for key, value in kwargs.items()
- }
- ret = func(*args, **kwargs)
- if use_torch():
- # if return value is a numpy array, convert it to Torch tensor
- if isinstance(ret, numpy.ndarray):
- ret = torch.from_numpy(ret)
- return ret
-
- return decorated
-
-
-@functools.cache
-def _get_real_torch():
- try:
- import torch as _real_torch
- except ImportError:
- _real_torch = None
- return _real_torch
-
-
-def log_no_torch_error():
- logging.error(
- "This command requires torch. You can install torch for bittensor"
- ' with `pip install bittensor[torch]` or `pip install ".[torch]"`'
- " if installing from source, and then run the command with USE_TORCH=1 {command}"
- )
-
-
-class LazyLoadedTorch:
- """A lazy-loading proxy for the torch module."""
-
- def __bool__(self):
- return bool(_get_real_torch())
-
- def __getattr__(self, name):
- if real_torch := _get_real_torch():
- return getattr(real_torch, name)
- else:
- log_no_torch_error()
- raise ImportError("torch not installed")
-
-
-if TYPE_CHECKING:
- import torch
- from bittensor.core.subtensor import Subtensor
- from bittensor_wallet import Wallet
-else:
- torch = LazyLoadedTorch()
-
-
-def _hex_bytes_to_u8_list(hex_bytes: bytes):
- hex_chunks = [int(hex_bytes[i : i + 2], 16) for i in range(0, len(hex_bytes), 2)]
- return hex_chunks
-
-
-def _create_seal_hash(block_and_hotkey_hash_bytes: bytes, nonce: int) -> bytes:
- """Create a seal hash for a given block and nonce."""
- nonce_bytes = binascii.hexlify(nonce.to_bytes(8, "little"))
- pre_seal = nonce_bytes + binascii.hexlify(block_and_hotkey_hash_bytes)[:64]
- seal_sh256 = hashlib.sha256(bytearray(_hex_bytes_to_u8_list(pre_seal))).digest()
- kec = keccak.new(digest_bits=256)
- seal = kec.update(seal_sh256).digest()
- return seal
-
-
-def _seal_meets_difficulty(seal: bytes, difficulty: int, limit: int):
- """Check if the seal meets the given difficulty criteria."""
- seal_number = int.from_bytes(seal, "big")
- product = seal_number * difficulty
- return product < limit
-
-
-@dataclasses.dataclass
-class POWSolution:
- """A solution to the registration PoW problem."""
-
- nonce: int
- block_number: int
- difficulty: int
- seal: bytes
-
- def is_stale(self, subtensor: "Subtensor") -> bool:
- """
- Returns True if the POW is stale.
-
- This means the block the POW is solved for is within 3 blocks of the current block.
- """
- return self.block_number < subtensor.get_current_block() - 3
-
-
-class _UsingSpawnStartMethod:
- def __init__(self, force: bool = False):
- self._old_start_method = None
- self._force = force
-
- def __enter__(self):
- self._old_start_method = multiprocessing.get_start_method(allow_none=True)
- if self._old_start_method is None:
- self._old_start_method = "spawn" # default to spawn
-
- multiprocessing.set_start_method("spawn", force=self._force)
-
- def __exit__(self, *args):
- # restore the old start method
- multiprocessing.set_start_method(self._old_start_method, force=True)
-
-
-class _SolverBase(multiprocessing.Process):
- """
- A process that solves the registration PoW problem.
-
- Args:
- proc_num (int): The number of the process being created.
- num_proc (int): The total number of processes running.
- update_interval (int): The number of nonces to try to solve before checking for a new block.
- finished_queue (multiprocessing.Queue): The queue to put the process number when a process finishes each update_interval. Used for calculating the average time per update_interval across all processes.
- solution_queue (multiprocessing.Queue): The queue to put the solution the process has found during the pow solve.
- newBlockEvent (multiprocessing.Event): The event to set by the main process when a new block is finalized in the network. The solver process will check for the event after each update_interval. The solver process will get the new block hash and difficulty and start solving for a new nonce.
- stopEvent (multiprocessing.Event): The event to set by the main process when all the solver processes should stop. The solver process will check for the event after each update_interval. The solver process will stop when the event is set. Used to stop the solver processes when a solution is found.
- curr_block (multiprocessing.Array): The array containing this process's current block hash. The main process will set the array to the new block hash when a new block is finalized in the network. The solver process will get the new block hash from this array when newBlockEvent is set.
- curr_block_num (multiprocessing.Value): The value containing this process's current block number. The main process will set the value to the new block number when a new block is finalized in the network. The solver process will get the new block number from this value when newBlockEvent is set.
- curr_diff (multiprocessing.Array): The array containing this process's current difficulty. The main process will set the array to the new difficulty when a new block is finalized in the network. The solver process will get the new difficulty from this array when newBlockEvent is set.
- check_block (multiprocessing.Lock): The lock to prevent this process from getting the new block data while the main process is updating the data.
- limit (int): The limit of the pow solve for a valid solution.
- """
-
- proc_num: int
- num_proc: int
- update_interval: int
- finished_queue: "multiprocessing.Queue"
- solution_queue: "multiprocessing.Queue"
- newBlockEvent: "multiprocessing.Event"
- stopEvent: "multiprocessing.Event"
- hotkey_bytes: bytes
- curr_block: "multiprocessing.Array"
- curr_block_num: "multiprocessing.Value"
- curr_diff: "multiprocessing.Array"
- check_block: "multiprocessing.Lock"
- limit: int
-
- def __init__(
- self,
- proc_num,
- num_proc,
- update_interval,
- finished_queue,
- solution_queue,
- stopEvent,
- curr_block,
- curr_block_num,
- curr_diff,
- check_block,
- limit,
- ):
- multiprocessing.Process.__init__(self, daemon=True)
- self.proc_num = proc_num
- self.num_proc = num_proc
- self.update_interval = update_interval
- self.finished_queue = finished_queue
- self.solution_queue = solution_queue
- self.newBlockEvent = multiprocessing.Event()
- self.newBlockEvent.clear()
- self.curr_block = curr_block
- self.curr_block_num = curr_block_num
- self.curr_diff = curr_diff
- self.check_block = check_block
- self.stopEvent = stopEvent
- self.limit = limit
-
- def run(self):
- raise NotImplementedError("_SolverBase is an abstract class")
-
- @staticmethod
- def create_shared_memory() -> (
- tuple["multiprocessing.Array", "multiprocessing.Value", "multiprocessing.Array"]
- ):
- """Creates shared memory for the solver processes to use."""
- curr_block = multiprocessing.Array("h", 32, lock=True) # byte array
- curr_block_num = multiprocessing.Value("i", 0, lock=True) # int
- curr_diff = multiprocessing.Array("Q", [0, 0], lock=True) # [high, low]
-
- return curr_block, curr_block_num, curr_diff
-
-
-class _Solver(_SolverBase):
- def run(self):
- block_number: int
- block_and_hotkey_hash_bytes: bytes
- block_difficulty: int
- nonce_limit = int(math.pow(2, 64)) - 1
-
- # Start at random nonce
- nonce_start = random.randint(0, nonce_limit)
- nonce_end = nonce_start + self.update_interval
- while not self.stopEvent.is_set():
- if self.newBlockEvent.is_set():
- with self.check_block:
- block_number = self.curr_block_num.value
- block_and_hotkey_hash_bytes = bytes(self.curr_block)
- block_difficulty = _registration_diff_unpack(self.curr_diff)
-
- self.newBlockEvent.clear()
-
- # Do a block of nonces
- solution = _solve_for_nonce_block(
- nonce_start,
- nonce_end,
- block_and_hotkey_hash_bytes,
- block_difficulty,
- self.limit,
- block_number,
- )
- if solution is not None:
- self.solution_queue.put(solution)
-
- try:
- # Send time
- self.finished_queue.put_nowait(self.proc_num)
- except Full:
- pass
-
- nonce_start = random.randint(0, nonce_limit)
- nonce_start = nonce_start % nonce_limit
- nonce_end = nonce_start + self.update_interval
-
-
-class _CUDASolver(_SolverBase):
- dev_id: int
- tpb: int
-
- def __init__(
- self,
- proc_num,
- num_proc,
- update_interval,
- finished_queue,
- solution_queue,
- stopEvent,
- curr_block,
- curr_block_num,
- curr_diff,
- check_block,
- limit,
- dev_id: int,
- tpb: int,
- ):
- super().__init__(
- proc_num,
- num_proc,
- update_interval,
- finished_queue,
- solution_queue,
- stopEvent,
- curr_block,
- curr_block_num,
- curr_diff,
- check_block,
- limit,
- )
- self.dev_id = dev_id
- self.tpb = tpb
-
- def run(self):
- block_number: int = 0 # dummy value
- block_and_hotkey_hash_bytes: bytes = b"0" * 32 # dummy value
- block_difficulty: int = int(math.pow(2, 64)) - 1 # dummy value
- nonce_limit = int(math.pow(2, 64)) - 1 # U64MAX
-
- # Start at random nonce
- nonce_start = random.randint(0, nonce_limit)
- while not self.stopEvent.is_set():
- if self.newBlockEvent.is_set():
- with self.check_block:
- block_number = self.curr_block_num.value
- block_and_hotkey_hash_bytes = bytes(self.curr_block)
- block_difficulty = _registration_diff_unpack(self.curr_diff)
-
- self.newBlockEvent.clear()
-
- # Do a block of nonces
- solution = _solve_for_nonce_block_cuda(
- nonce_start,
- self.update_interval,
- block_and_hotkey_hash_bytes,
- block_difficulty,
- self.limit,
- block_number,
- self.dev_id,
- self.tpb,
- )
- if solution is not None:
- self.solution_queue.put(solution)
-
- try:
- # Signal that a nonce_block was finished using queue
- # send our proc_num
- self.finished_queue.put(self.proc_num)
- except Full:
- pass
-
- # increase nonce by number of nonces processed
- nonce_start += self.update_interval * self.tpb
- nonce_start = nonce_start % nonce_limit
-
-
-def _solve_for_nonce_block_cuda(
- nonce_start: int,
- update_interval: int,
- block_and_hotkey_hash_bytes: bytes,
- difficulty: int,
- limit: int,
- block_number: int,
- dev_id: int,
- tpb: int,
-) -> Optional["POWSolution"]:
- """Tries to solve the POW on a CUDA device for a block of nonces (nonce_start, nonce_start + update_interval * tpb"""
- solution, seal = solve_cuda(
- nonce_start,
- update_interval,
- tpb,
- block_and_hotkey_hash_bytes,
- difficulty,
- limit,
- dev_id,
- )
-
- if solution != -1:
- # Check if solution is valid (i.e. not -1)
- return POWSolution(solution, block_number, difficulty, seal)
-
- return None
-
-
-def _solve_for_nonce_block(
- nonce_start: int,
- nonce_end: int,
- block_and_hotkey_hash_bytes: bytes,
- difficulty: int,
- limit: int,
- block_number: int,
-) -> Optional["POWSolution"]:
- """Tries to solve the POW for a block of nonces (nonce_start, nonce_end)"""
- for nonce in range(nonce_start, nonce_end):
- # Create seal.
- seal = _create_seal_hash(block_and_hotkey_hash_bytes, nonce)
-
- # Check if seal meets difficulty
- if _seal_meets_difficulty(seal, difficulty, limit):
- # Found a solution, save it.
- return POWSolution(nonce, block_number, difficulty, seal)
-
- return None
-
-
-def _registration_diff_unpack(packed_diff: "multiprocessing.Array") -> int:
- """Unpacks the packed two 32-bit integers into one 64-bit integer. Little endian."""
- return int(packed_diff[0] << 32 | packed_diff[1])
-
-
-def _registration_diff_pack(diff: int, packed_diff: "multiprocessing.Array"):
- """Packs the difficulty into two 32-bit integers. Little endian."""
- packed_diff[0] = diff >> 32
- packed_diff[1] = diff & 0xFFFFFFFF # low 32 bits
-
-
-def _hash_block_with_hotkey(block_bytes: bytes, hotkey_bytes: bytes) -> bytes:
- """Hashes the block with the hotkey using Keccak-256 to get 32 bytes"""
- kec = keccak.new(digest_bits=256)
- kec = kec.update(bytearray(block_bytes + hotkey_bytes))
- block_and_hotkey_hash_bytes = kec.digest()
- return block_and_hotkey_hash_bytes
-
-
-def _update_curr_block(
- curr_diff: "multiprocessing.Array",
- curr_block: "multiprocessing.Array",
- curr_block_num: "multiprocessing.Value",
- block_number: int,
- block_bytes: bytes,
- diff: int,
- hotkey_bytes: bytes,
- lock: "multiprocessing.Lock",
-):
- """Updates the current block's information atomically using a lock."""
- with lock:
- curr_block_num.value = block_number
- # Hash the block with the hotkey
- block_and_hotkey_hash_bytes = _hash_block_with_hotkey(block_bytes, hotkey_bytes)
- for i in range(32):
- curr_block[i] = block_and_hotkey_hash_bytes[i]
- _registration_diff_pack(diff, curr_diff)
-
-
-def get_cpu_count() -> int:
- """Returns the number of CPUs in the system."""
- try:
- return len(os.sched_getaffinity(0))
- except AttributeError:
- # OSX does not have sched_getaffinity
- return os.cpu_count()
-
-
-@dataclasses.dataclass
-class RegistrationStatistics:
- """Statistics for a registration."""
-
- time_spent_total: float
- rounds_total: int
- time_average: float
- time_spent: float
- hash_rate_perpetual: float
- hash_rate: float
- difficulty: int
- block_number: int
- block_hash: bytes
-
-
-class RegistrationStatisticsLogger:
- """Logs statistics for a registration."""
-
- status: Optional[rich_status.Status]
-
- def __init__(
- self,
- console: Optional[rich_console.Console] = None,
- output_in_place: bool = True,
- ) -> None:
- if console is None:
- console = Console()
-
- self.console = console
-
- if output_in_place:
- self.status = self.console.status("Solving")
- else:
- self.status = None
-
- def start(self) -> None:
- if self.status is not None:
- self.status.start()
-
- def stop(self) -> None:
- if self.status is not None:
- self.status.stop()
-
- def get_status_message(
- self, stats: RegistrationStatistics, verbose: bool = False
- ) -> str:
- """Generates the status message based on registration statistics."""
- message = (
- "Solving\n"
- + f"Time Spent (total): [bold white]{timedelta(seconds=stats.time_spent_total)}[/bold white]\n"
- + (
- f"Time Spent This Round: {timedelta(seconds=stats.time_spent)}\n"
- + f"Time Spent Average: {timedelta(seconds=stats.time_average)}\n"
- if verbose
- else ""
- )
- + f"Registration Difficulty: [bold white]{millify(stats.difficulty)}[/bold white]\n"
- + f"Iters (Inst/Perp): [bold white]{get_human_readable(stats.hash_rate, 'H')}/s / "
- + f"{get_human_readable(stats.hash_rate_perpetual, 'H')}/s[/bold white]\n"
- + f"Block Number: [bold white]{stats.block_number}[/bold white]\n"
- + f"Block Hash: [bold white]{stats.block_hash.encode('utf-8')}[/bold white]\n"
- )
- return message
-
- def update(self, stats: RegistrationStatistics, verbose: bool = False) -> None:
- if self.status is not None:
- self.status.update(self.get_status_message(stats, verbose=verbose))
- else:
- self.console.log(self.get_status_message(stats, verbose=verbose))
-
-
-def _solve_for_difficulty_fast(
- subtensor: "Subtensor",
- wallet: "Wallet",
- netuid: int,
- output_in_place: bool = True,
- num_processes: Optional[int] = None,
- update_interval: Optional[int] = None,
- n_samples: int = 10,
- alpha_: float = 0.80,
- log_verbose: bool = False,
-) -> Optional[POWSolution]:
- """
- Solves the POW for registration using multiprocessing.
-
- Args:
- subtensor (bittensor.core.subtensor.Subtensor): Subtensor instance to connect to for block information and to submit.
- wallet (bittensor_wallet.Wallet): wallet to use for registration.
- netuid (int): The netuid of the subnet to register to.
- output_in_place (bool): If true, prints the status in place. Otherwise, prints the status on a new line.
- num_processes (int): Number of processes to use.
- update_interval (int): Number of nonces to solve before updating block information.
- n_samples (int): The number of samples of the hash_rate to keep for the EWMA.
- alpha_ (float): The alpha for the EWMA for the hash_rate calculation.
- log_verbose (bool): If true, prints more verbose logging of the registration metrics.
-
- Note: The hash rate is calculated as an exponentially weighted moving average in order to make the measure more robust.
- Note: We can also modify the update interval to do smaller blocks of work, while still updating the block information after a different number of nonces, to increase the transparency of the process while still keeping the speed.
- """
- if num_processes is None:
- # get the number of allowed processes for this process
- num_processes = min(1, get_cpu_count())
-
- if update_interval is None:
- update_interval = 50_000
-
- limit = int(math.pow(2, 256)) - 1
-
- curr_block, curr_block_num, curr_diff = _Solver.create_shared_memory()
-
- # Establish communication queues
- # See the _Solver class for more information on the queues.
- stopEvent = multiprocessing.Event()
- stopEvent.clear()
-
- solution_queue = multiprocessing.Queue()
- finished_queues = [multiprocessing.Queue() for _ in range(num_processes)]
- check_block = multiprocessing.Lock()
-
- hotkey_bytes = (
- wallet.coldkeypub.public_key if netuid == -1 else wallet.hotkey.public_key
- )
- # Start consumers
- solvers = [
- _Solver(
- i,
- num_processes,
- update_interval,
- finished_queues[i],
- solution_queue,
- stopEvent,
- curr_block,
- curr_block_num,
- curr_diff,
- check_block,
- limit,
- )
- for i in range(num_processes)
- ]
-
- # Get first block
- block_number, difficulty, block_hash = _get_block_with_retry(
- subtensor=subtensor, netuid=netuid
- )
-
- block_bytes = bytes.fromhex(block_hash[2:])
- old_block_number = block_number
- # Set to current block
- _update_curr_block(
- curr_diff,
- curr_block,
- curr_block_num,
- block_number,
- block_bytes,
- difficulty,
- hotkey_bytes,
- check_block,
- )
-
- # Set new block events for each solver to start at the initial block
- for worker in solvers:
- worker.newBlockEvent.set()
-
- for worker in solvers:
- worker.start() # start the solver processes
-
- start_time = time.time() # time that the registration started
- time_last = start_time # time that the last work blocks completed
-
- curr_stats = RegistrationStatistics(
- time_spent_total=0.0,
- time_average=0.0,
- rounds_total=0,
- time_spent=0.0,
- hash_rate_perpetual=0.0,
- hash_rate=0.0,
- difficulty=difficulty,
- block_number=block_number,
- block_hash=block_hash,
- )
-
- start_time_perpetual = time.time()
-
- logger = RegistrationStatisticsLogger(output_in_place=output_in_place)
- logger.start()
-
- solution = None
-
- hash_rates = [0] * n_samples # The last n true hash_rates
- weights = [alpha_**i for i in range(n_samples)] # weights decay by alpha
-
- while netuid == -1 or not subtensor.is_hotkey_registered(
- netuid=netuid, hotkey_ss58=wallet.hotkey.ss58_address
- ):
- # Wait until a solver finds a solution
- try:
- solution = solution_queue.get(block=True, timeout=0.25)
- if solution is not None:
- break
- except Empty:
- # No solution found, try again
- pass
-
- # check for new block
- old_block_number = _check_for_newest_block_and_update(
- subtensor=subtensor,
- netuid=netuid,
- hotkey_bytes=hotkey_bytes,
- old_block_number=old_block_number,
- curr_diff=curr_diff,
- curr_block=curr_block,
- curr_block_num=curr_block_num,
- curr_stats=curr_stats,
- update_curr_block=_update_curr_block,
- check_block=check_block,
- solvers=solvers,
- )
-
- num_time = 0
- for finished_queue in finished_queues:
- try:
- proc_num = finished_queue.get(timeout=0.1)
- num_time += 1
-
- except Empty:
- continue
-
- time_now = time.time() # get current time
- time_since_last = time_now - time_last # get time since last work block(s)
- if num_time > 0 and time_since_last > 0.0:
- # create EWMA of the hash_rate to make measure more robust
-
- hash_rate_ = (num_time * update_interval) / time_since_last
- hash_rates.append(hash_rate_)
- hash_rates.pop(0) # remove the 0th data point
- curr_stats.hash_rate = sum(
- [hash_rates[i] * weights[i] for i in range(n_samples)]
- ) / (sum(weights))
-
- # update time last to now
- time_last = time_now
-
- curr_stats.time_average = (
- curr_stats.time_average * curr_stats.rounds_total
- + curr_stats.time_spent
- ) / (curr_stats.rounds_total + num_time)
- curr_stats.rounds_total += num_time
-
- # Update stats
- curr_stats.time_spent = time_since_last
- new_time_spent_total = time_now - start_time_perpetual
- curr_stats.hash_rate_perpetual = (
- curr_stats.rounds_total * update_interval
- ) / new_time_spent_total
- curr_stats.time_spent_total = new_time_spent_total
-
- # Update the logger
- logger.update(curr_stats, verbose=log_verbose)
-
- # exited while, solution contains the nonce or wallet is registered
- stopEvent.set() # stop all other processes
- logger.stop()
-
- # terminate and wait for all solvers to exit
- _terminate_workers_and_wait_for_exit(solvers)
-
- return solution
-
-
-@retry(Exception, tries=3, delay=1)
-def _get_block_with_retry(subtensor: "Subtensor", netuid: int) -> tuple[int, int, str]:
- """
- Gets the current block number, difficulty, and block hash from the substrate node.
-
- Args:
- subtensor (bittensor.core.subtensor.Subtensor): The subtensor object to use to get the block number, difficulty, and block hash.
- netuid (int): The netuid of the network to get the block number, difficulty, and block hash from.
-
- Returns:
- tuple[int, int, bytes]
- block_number (int): The current block number.
- difficulty (int): The current difficulty of the subnet.
- block_hash (bytes): The current block hash.
-
- Raises:
- Exception: If the block hash is None.
- ValueError: If the difficulty is None.
- """
- block_number = subtensor.get_current_block()
- difficulty = 1_000_000 if netuid == -1 else subtensor.difficulty(netuid=netuid)
- block_hash = subtensor.get_block_hash(block_number)
- if block_hash is None:
- raise Exception(
- "Network error. Could not connect to substrate to get block hash"
- )
- if difficulty is None:
- raise ValueError("Chain error. Difficulty is None")
- return block_number, difficulty, block_hash
-
-
-def _check_for_newest_block_and_update(
- subtensor: "Subtensor",
- netuid: int,
- old_block_number: int,
- hotkey_bytes: bytes,
- curr_diff: "multiprocessing.Array",
- curr_block: "multiprocessing.Array",
- curr_block_num: "multiprocessing.Value",
- update_curr_block: "Callable",
- check_block: "multiprocessing.Lock",
- solvers: Union[list["_Solver"], list["_CUDASolver"]],
- curr_stats: "RegistrationStatistics",
-) -> int:
- """
- Checks for a new block and updates the current block information if a new block is found.
-
- Args:
- subtensor (bittensor.core.subtensor.Subtensor): The subtensor object to use for getting the current block.
- netuid (int): The netuid to use for retrieving the difficulty.
- old_block_number (int): The old block number to check against.
- hotkey_bytes (bytes): The bytes of the hotkey's pubkey.
- curr_diff (multiprocessing.Array): The current difficulty as a multiprocessing array.
- curr_block (multiprocessing.Array): Where the current block is stored as a multiprocessing array.
- curr_block_num (multiprocessing.Value): Where the current block number is stored as a multiprocessing value.
- update_curr_block (typing.Callable): A function that updates the current block.
- check_block (multiprocessing.Lock): A mp lock that is used to check for a new block.
- solvers (list[bittensor.utils.registration._Solver]): A list of solvers to update the current block for.
- curr_stats (bittensor.utils.registration.RegistrationStatistics): The current registration statistics to update.
-
- Returns:
- (int) The current block number.
- """
- block_number = subtensor.get_current_block()
- if block_number != old_block_number:
- old_block_number = block_number
- # update block information
- block_number, difficulty, block_hash = _get_block_with_retry(
- subtensor=subtensor, netuid=netuid
- )
- block_bytes = bytes.fromhex(block_hash[2:])
-
- update_curr_block(
- curr_diff,
- curr_block,
- curr_block_num,
- block_number,
- block_bytes,
- difficulty,
- hotkey_bytes,
- check_block,
- )
- # Set new block events for each solver
-
- for worker in solvers:
- worker.newBlockEvent.set()
-
- # update stats
- curr_stats.block_number = block_number
- curr_stats.block_hash = block_hash
- curr_stats.difficulty = difficulty
-
- return old_block_number
-
-
-def _solve_for_difficulty_fast_cuda(
- subtensor: "Subtensor",
- wallet: "Wallet",
- netuid: int,
- output_in_place: bool = True,
- update_interval: int = 50_000,
- tpb: int = 512,
- dev_id: Union[list[int], int] = 0,
- n_samples: int = 10,
- alpha_: float = 0.80,
- log_verbose: bool = False,
-) -> Optional["POWSolution"]:
- """
- Solves the registration fast using CUDA.
-
- Args:
- subtensor (bittensor.core.subtensor.Subtensor): The subtensor node to grab blocks.
- wallet (bittensor_wallet.Wallet): The wallet to register.
- netuid (int): The netuid of the subnet to register to.
- output_in_place (bool) If true, prints the output in place, otherwise prints to new lines.
- update_interval (int): The number of nonces to try before checking for more blocks.
- tpb (int): The number of threads per block. CUDA param that should match the GPU capability
- dev_id (Union[list[int], int]): The CUDA device IDs to execute the registration on, either a single device or a list of devices.
- n_samples (int): The number of samples of the hash_rate to keep for the EWMA.
- alpha_ (float): The alpha for the EWMA for the hash_rate calculation.
- log_verbose (bool): If true, prints more verbose logging of the registration metrics.
-
- Note: The hash rate is calculated as an exponentially weighted moving average in order to make the measure more robust.
- """
- if isinstance(dev_id, int):
- dev_id = [dev_id]
- elif dev_id is None:
- dev_id = [0]
-
- if update_interval is None:
- update_interval = 50_000
-
- if not torch.cuda.is_available():
- raise Exception("CUDA not available")
-
- limit = int(math.pow(2, 256)) - 1
-
- # Set mp start to use spawn so CUDA doesn't complain
- with _UsingSpawnStartMethod(force=True):
- curr_block, curr_block_num, curr_diff = _CUDASolver.create_shared_memory()
-
- # Create a worker per CUDA device
- num_processes = len(dev_id)
-
- # Establish communication queues
- stopEvent = multiprocessing.Event()
- stopEvent.clear()
- solution_queue = multiprocessing.Queue()
- finished_queues = [multiprocessing.Queue() for _ in range(num_processes)]
- check_block = multiprocessing.Lock()
-
- hotkey_bytes = wallet.hotkey.public_key
- # Start workers
- solvers = [
- _CUDASolver(
- i,
- num_processes,
- update_interval,
- finished_queues[i],
- solution_queue,
- stopEvent,
- curr_block,
- curr_block_num,
- curr_diff,
- check_block,
- limit,
- dev_id[i],
- tpb,
- )
- for i in range(num_processes)
- ]
-
- # Get first block
- block_number, difficulty, block_hash = _get_block_with_retry(
- subtensor=subtensor, netuid=netuid
- )
-
- block_bytes = bytes.fromhex(block_hash[2:])
- old_block_number = block_number
-
- # Set to current block
- _update_curr_block(
- curr_diff,
- curr_block,
- curr_block_num,
- block_number,
- block_bytes,
- difficulty,
- hotkey_bytes,
- check_block,
- )
-
- # Set new block events for each solver to start at the initial block
- for worker in solvers:
- worker.newBlockEvent.set()
-
- for worker in solvers:
- worker.start() # start the solver processes
-
- start_time = time.time() # time that the registration started
- time_last = start_time # time that the last work blocks completed
-
- curr_stats = RegistrationStatistics(
- time_spent_total=0.0,
- time_average=0.0,
- rounds_total=0,
- time_spent=0.0,
- hash_rate_perpetual=0.0,
- hash_rate=0.0, # EWMA hash_rate (H/s)
- difficulty=difficulty,
- block_number=block_number,
- block_hash=block_hash,
- )
-
- start_time_perpetual = time.time()
-
- logger = RegistrationStatisticsLogger(output_in_place=output_in_place)
- logger.start()
-
- hash_rates = [0] * n_samples # The last n true hash_rates
- weights = [alpha_**i for i in range(n_samples)] # weights decay by alpha
-
- solution = None
- while netuid == -1 or not subtensor.is_hotkey_registered(
- netuid=netuid, hotkey_ss58=wallet.hotkey.ss58_address
- ):
- # Wait until a solver finds a solution
- try:
- solution = solution_queue.get(block=True, timeout=0.15)
- if solution is not None:
- break
- except Empty:
- # No solution found, try again
- pass
-
- # check for new block
- old_block_number = _check_for_newest_block_and_update(
- subtensor=subtensor,
- netuid=netuid,
- hotkey_bytes=hotkey_bytes,
- curr_diff=curr_diff,
- curr_block=curr_block,
- curr_block_num=curr_block_num,
- old_block_number=old_block_number,
- curr_stats=curr_stats,
- update_curr_block=_update_curr_block,
- check_block=check_block,
- solvers=solvers,
- )
-
- num_time = 0
- # Get times for each solver
- for finished_queue in finished_queues:
- try:
- proc_num = finished_queue.get(timeout=0.1)
- num_time += 1
-
- except Empty:
- continue
-
- time_now = time.time() # get current time
- time_since_last = time_now - time_last # get time since last work block(s)
- if num_time > 0 and time_since_last > 0.0:
- # create EWMA of the hash_rate to make measure more robust
-
- hash_rate_ = (num_time * tpb * update_interval) / time_since_last
- hash_rates.append(hash_rate_)
- hash_rates.pop(0) # remove the 0th data point
- curr_stats.hash_rate = sum(
- [hash_rates[i] * weights[i] for i in range(n_samples)]
- ) / (sum(weights))
-
- # update time last to now
- time_last = time_now
-
- curr_stats.time_average = (
- curr_stats.time_average * curr_stats.rounds_total
- + curr_stats.time_spent
- ) / (curr_stats.rounds_total + num_time)
- curr_stats.rounds_total += num_time
-
- # Update stats
- curr_stats.time_spent = time_since_last
- new_time_spent_total = time_now - start_time_perpetual
- curr_stats.hash_rate_perpetual = (
- curr_stats.rounds_total * (tpb * update_interval)
- ) / new_time_spent_total
- curr_stats.time_spent_total = new_time_spent_total
-
- # Update the logger
- logger.update(curr_stats, verbose=log_verbose)
-
- # exited while, found_solution contains the nonce or wallet is registered
-
- stopEvent.set() # stop all other processes
- logger.stop()
-
- # terminate and wait for all solvers to exit
- _terminate_workers_and_wait_for_exit(solvers)
-
- return solution
-
-
-def _terminate_workers_and_wait_for_exit(
- workers: list[Union[multiprocessing.Process, QueueType]],
-) -> None:
- for worker in workers:
- if isinstance(worker, QueueType):
- worker.join_thread()
- else:
- try:
- worker.join(3.0)
- except subprocess.TimeoutExpired:
- worker.terminate()
- try:
- worker.close()
- except ValueError:
- worker.terminate()
-
-
-def create_pow(
- subtensor: "Subtensor",
- wallet: "Wallet",
- netuid: int,
- output_in_place: bool = True,
- cuda: bool = False,
- dev_id: Union[list[int], int] = 0,
- tpb: int = 256,
- num_processes: Optional[int] = None,
- update_interval: Optional[int] = None,
- log_verbose: bool = False,
-) -> Optional[dict[str, Any]]:
- """
- Creates a proof of work for the given subtensor and wallet.
-
- Args:
- subtensor (bittensor.core.subtensor.Subtensor): The subtensor to create a proof of work for.
- wallet (bittensor_wallet.Wallet): The wallet to create a proof of work for.
- netuid (int): The netuid for the subnet to create a proof of work for.
- output_in_place (bool): If true, prints the progress of the proof of work to the console in-place. Meaning the progress is printed on the same lines. Default is ``True``.
- cuda (bool): If true, uses CUDA to solve the proof of work. Default is ``False``.
- dev_id (Union[List[int], int]): The CUDA device id(s) to use. If cuda is true and dev_id is a list, then multiple CUDA devices will be used to solve the proof of work. Default is ``0``.
- tpb (int): The number of threads per block to use when solving the proof of work. Should be a multiple of 32. Default is ``256``.
- num_processes (Optional[int]): The number of processes to use when solving the proof of work. If None, then the number of processes is equal to the number of CPU cores. Default is None.
- update_interval (Optional[int]): The number of nonces to run before checking for a new block. Default is ``None``.
- log_verbose (bool): If true, prints the progress of the proof of work more verbosely. Default is ``False``.
-
- Returns:
- Optional[Dict[str, Any]]: The proof of work solution or None if the wallet is already registered or there is a different error.
-
- Raises:
- ValueError: If the subnet does not exist.
- """
- if netuid != -1:
- if not subtensor.subnet_exists(netuid=netuid):
- raise ValueError(f"Subnet {netuid} does not exist.")
-
- if cuda:
- solution: Optional[POWSolution] = _solve_for_difficulty_fast_cuda(
- subtensor,
- wallet,
- netuid=netuid,
- output_in_place=output_in_place,
- dev_id=dev_id,
- tpb=tpb,
- update_interval=update_interval,
- log_verbose=log_verbose,
- )
- else:
- solution: Optional[POWSolution] = _solve_for_difficulty_fast(
- subtensor,
- wallet,
- netuid=netuid,
- output_in_place=output_in_place,
- num_processes=num_processes,
- update_interval=update_interval,
- log_verbose=log_verbose,
- )
- return solution
diff --git a/bittensor/utils/registration/__init__.py b/bittensor/utils/registration/__init__.py
new file mode 100644
index 0000000000..37a913e20a
--- /dev/null
+++ b/bittensor/utils/registration/__init__.py
@@ -0,0 +1,10 @@
+from bittensor.utils.registration.pow import (
+ create_pow,
+ legacy_torch_api_compat,
+ log_no_torch_error,
+ torch,
+ use_torch,
+ LazyLoadedTorch,
+ POWSolution,
+)
+from bittensor.utils.registration.async_pow import create_pow_async
diff --git a/bittensor/utils/registration/async_pow.py b/bittensor/utils/registration/async_pow.py
new file mode 100644
index 0000000000..02817620ac
--- /dev/null
+++ b/bittensor/utils/registration/async_pow.py
@@ -0,0 +1,537 @@
+"""This module provides async utilities for solving Proof-of-Work (PoW) challenges in Bittensor network."""
+
+import math
+import time
+from multiprocessing import Event, Lock, Array, Value, Queue
+from queue import Empty
+from typing import Callable, Union, Optional, TYPE_CHECKING
+
+from retry import retry
+from substrateinterface.exceptions import SubstrateRequestException
+
+from bittensor.utils.registration.pow import (
+ get_cpu_count,
+ update_curr_block,
+ terminate_workers_and_wait_for_exit,
+ CUDASolver,
+ LazyLoadedTorch,
+ RegistrationStatistics,
+ RegistrationStatisticsLogger,
+ Solver,
+ UsingSpawnStartMethod,
+)
+
+if TYPE_CHECKING:
+ from bittensor.core.async_subtensor import AsyncSubtensor
+ from bittensor_wallet import Wallet
+ from bittensor.utils.registration import POWSolution
+ import torch
+else:
+ torch = LazyLoadedTorch()
+
+
async def _get_block_with_retry(
    subtensor: "AsyncSubtensor", netuid: int
) -> tuple[int, int, str]:
    """
    Gets the current block number, difficulty, and block hash from the substrate node.

    Retries up to 3 times, 1 second apart, on any failure. NOTE: the previous
    version used the ``retry`` decorator, which is a no-op for coroutine
    functions — the decorated call returns the coroutine object immediately,
    so exceptions only surface at ``await`` time in the caller, outside the
    decorator's retry loop. The retry is therefore implemented manually here.

    Args:
        subtensor (bittensor.core.async_subtensor.AsyncSubtensor): The subtensor object to use to get the block number, difficulty, and block hash.
        netuid (int): The netuid of the network to get the block number, difficulty, and block hash from.

    Returns:
        The current block number, difficulty of the subnet, block hash.

    Raises:
        Exception: If the substrate request fails after all retries.
        ValueError: If the difficulty is None.
    """
    import asyncio  # local import: avoids touching the module's import block

    tries = 3
    delay = 1
    last_error: Optional[BaseException] = None
    for attempt in range(tries):
        try:
            block = await subtensor.substrate.get_block()
            block_hash = block["header"]["hash"]
            block_number = block["header"]["number"]
            try:
                difficulty = (
                    1_000_000
                    if netuid == -1
                    else int(
                        await subtensor.get_hyperparameter(
                            param_name="Difficulty",
                            netuid=netuid,
                            block_hash=block_hash,
                        )
                    )
                )
            except TypeError:
                # get_hyperparameter returned None -> int(None) raises TypeError.
                raise ValueError("Chain error. Difficulty is None")
            except SubstrateRequestException:
                raise Exception(
                    "Network error. Could not connect to substrate to get block hash"
                )
            return block_number, difficulty, block_hash
        except Exception as error:
            # Matches the old @retry(Exception, tries=3, delay=1) contract:
            # every exception type is retried, the last one is re-raised.
            last_error = error
            if attempt < tries - 1:
                await asyncio.sleep(delay)
    raise last_error
+
+
async def _check_for_newest_block_and_update(
    subtensor: "AsyncSubtensor",
    netuid: int,
    old_block_number: int,
    hotkey_bytes: bytes,
    curr_diff: Array,
    curr_block: Array,
    curr_block_num: Value,
    update_curr_block_: "Callable",
    check_block: Lock,
    solvers: list[Solver],
    curr_stats: "RegistrationStatistics",
) -> int:
    """
    Poll the chain for a newer block and, if one exists, push the refreshed
    block data out to all solver processes and the running statistics.

    Args:
        subtensor (AsyncSubtensor): The subtensor instance interface.
        netuid (int): The network UID for the blockchain.
        old_block_number (int): The previously known block number.
        hotkey_bytes (bytes): The bytes representation of the hotkey.
        curr_diff (Array): Shared memory holding the current difficulty.
        curr_block (Array): Shared memory holding the current block bytes.
        curr_block_num (Value): Shared memory holding the current block number.
        update_curr_block_ (Callable): Function that writes new block data into shared memory.
        check_block (Lock): Lock synchronizing access to the shared block data.
        solvers (list[Solver]): Solver processes to notify of the new block.
        curr_stats (RegistrationStatistics): Running statistics to refresh.

    Returns:
        int: The newest observed block number (unchanged if no new block appeared).
    """
    latest_block_number = await subtensor.substrate.get_block_number(None)
    if latest_block_number == old_block_number:
        # Chain has not advanced; nothing to do.
        return old_block_number

    # A new block was produced: re-fetch block data (with retries) and fan the
    # update out to every solver process via shared memory.
    refreshed_number, refreshed_difficulty, refreshed_hash = (
        await _get_block_with_retry(subtensor=subtensor, netuid=netuid)
    )
    update_curr_block_(
        curr_diff,
        curr_block,
        curr_block_num,
        refreshed_number,
        bytes.fromhex(refreshed_hash[2:]),  # strip the "0x" prefix
        refreshed_difficulty,
        hotkey_bytes,
        check_block,
    )

    # Wake each solver so it picks up the new block parameters.
    for solver in solvers:
        solver.newBlockEvent.set()

    # Mirror the refreshed data into the running statistics.
    curr_stats.block_number = refreshed_number
    curr_stats.block_hash = refreshed_hash
    curr_stats.difficulty = refreshed_difficulty

    return latest_block_number
+
+
async def _block_solver(
    subtensor: "AsyncSubtensor",
    wallet: "Wallet",
    num_processes: int,
    netuid: int,
    dev_id: list[int],
    tpb: int,
    update_interval: int,
    curr_block,
    curr_block_num,
    curr_diff,
    n_samples: int,
    alpha_: float,
    output_in_place: bool,
    log_verbose: bool,
    cuda: bool,
):
    """Shared code used by the Solvers to solve the POW solution.

    Spawns solver processes (CUDA or CPU), seeds them with the current block,
    then polls for a solution while tracking and logging hash-rate statistics.

    Args:
        subtensor (AsyncSubtensor): Subtensor client for block data and registration checks.
        wallet (Wallet): Wallet to register.
        num_processes (int): Number of CPU solver processes (overridden by ``len(dev_id)`` when ``cuda``).
        netuid (int): The netuid of the subnet to register to (``-1`` for root).
        dev_id (list[int]): CUDA device ids (used only when ``cuda`` is True).
        tpb (int): CUDA threads per block (used only when ``cuda`` is True).
        update_interval (int): Number of nonces each solver tries before reporting back.
        curr_block: Shared-memory array holding the current block bytes.
        curr_block_num: Shared-memory value holding the current block number.
        curr_diff: Shared-memory array holding the current difficulty.
        n_samples (int): Number of hash-rate samples kept for the EWMA.
        alpha_ (float): EWMA decay factor for the hash-rate calculation.
        output_in_place (bool): If True, prints the status in place; otherwise on new lines.
        log_verbose (bool): If True, logs more verbose registration metrics.
        cuda (bool): Whether to run CUDA solvers instead of CPU solvers.

    Returns:
        Optional[POWSolution]: The first solution found, or ``None`` if the
        hotkey became registered before a solution was produced.
    """
    # Maximum value of a 256-bit hash.
    limit = int(math.pow(2, 256)) - 1

    if cuda:
        # One worker per CUDA device.
        num_processes = len(dev_id)

    # Establish communication queues.
    # See the Solver class for more information on the queues.
    stop_event = Event()
    stop_event.clear()

    solution_queue = Queue()
    finished_queues = [Queue() for _ in range(num_processes)]
    check_block = Lock()

    # Root network (netuid == -1) registers with the coldkey public key.
    hotkey_bytes = (
        wallet.coldkeypub.public_key if netuid == -1 else wallet.hotkey.public_key
    )

    if cuda:
        # Create a worker per CUDA device
        solvers = [
            CUDASolver(
                i,
                num_processes,
                update_interval,
                finished_queues[i],
                solution_queue,
                stop_event,
                curr_block,
                curr_block_num,
                curr_diff,
                check_block,
                limit,
                dev_id[i],
                tpb,
            )
            for i in range(num_processes)
        ]
    else:
        # Start consumers
        solvers = [
            Solver(
                i,
                num_processes,
                update_interval,
                finished_queues[i],
                solution_queue,
                stop_event,
                curr_block,
                curr_block_num,
                curr_diff,
                check_block,
                limit,
            )
            for i in range(num_processes)
        ]

    # Get first block
    block_number, difficulty, block_hash = await _get_block_with_retry(
        subtensor=subtensor, netuid=netuid
    )

    block_bytes = bytes.fromhex(block_hash[2:])
    old_block_number = block_number
    # Seed the shared memory with the current block.
    update_curr_block(
        curr_diff,
        curr_block,
        curr_block_num,
        block_number,
        block_bytes,
        difficulty,
        hotkey_bytes,
        check_block,
    )

    # Set new block events for each solver to start at the initial block
    for worker in solvers:
        worker.newBlockEvent.set()

    for worker in solvers:
        worker.start()  # start the solver processes

    start_time = time.time()  # time that the registration started
    time_last = start_time  # time that the last work blocks completed

    curr_stats = RegistrationStatistics(
        time_spent_total=0.0,
        time_average=0.0,
        rounds_total=0,
        time_spent=0.0,
        hash_rate_perpetual=0.0,
        hash_rate=0.0,
        difficulty=difficulty,
        block_number=block_number,
        block_hash=block_hash,
    )

    start_time_perpetual = time.time()

    logger = RegistrationStatisticsLogger(output_in_place=output_in_place)
    logger.start()

    solution = None

    hash_rates = [0] * n_samples  # The last n true hash_rates
    weights = [alpha_**i for i in range(n_samples)]  # weights decay by alpha

    # CUDA workers clear work units faster, so poll their solution queue more
    # frequently; the CPU timeout mirrors the synchronous implementation's
    # 0.25s. (Previously `0.15 if cuda else 0.15`, which made the conditional
    # a no-op.)
    timeout = 0.15 if cuda else 0.25
    while netuid == -1 or not await subtensor.is_hotkey_registered(
        netuid, wallet.hotkey.ss58_address
    ):
        # Wait until a solver finds a solution
        try:
            solution = solution_queue.get(block=True, timeout=timeout)
            if solution is not None:
                break
        except Empty:
            # No solution found, try again
            pass

        # check for new block
        old_block_number = await _check_for_newest_block_and_update(
            subtensor=subtensor,
            netuid=netuid,
            hotkey_bytes=hotkey_bytes,
            old_block_number=old_block_number,
            curr_diff=curr_diff,
            curr_block=curr_block,
            curr_block_num=curr_block_num,
            curr_stats=curr_stats,
            update_curr_block_=update_curr_block,
            check_block=check_block,
            solvers=solvers,
        )

        # Count how many work units finished since the last pass.
        num_time = 0
        for finished_queue in finished_queues:
            try:
                finished_queue.get(timeout=0.1)
                num_time += 1

            except Empty:
                continue

        time_now = time.time()  # get current time
        time_since_last = time_now - time_last  # get time since last work block(s)
        if num_time > 0 and time_since_last > 0.0:
            # create EWMA of the hash_rate to make measure more robust

            if cuda:
                hash_rate_ = (num_time * tpb * update_interval) / time_since_last
            else:
                hash_rate_ = (num_time * update_interval) / time_since_last
            hash_rates.append(hash_rate_)
            hash_rates.pop(0)  # remove the 0th data point
            curr_stats.hash_rate = sum(
                [hash_rates[i] * weights[i] for i in range(n_samples)]
            ) / (sum(weights))

            # update time last to now
            time_last = time_now

            curr_stats.time_average = (
                curr_stats.time_average * curr_stats.rounds_total
                + curr_stats.time_spent
            ) / (curr_stats.rounds_total + num_time)
            curr_stats.rounds_total += num_time

        # Update stats
        curr_stats.time_spent = time_since_last
        new_time_spent_total = time_now - start_time_perpetual
        if cuda:
            curr_stats.hash_rate_perpetual = (
                curr_stats.rounds_total * (tpb * update_interval)
            ) / new_time_spent_total
        else:
            curr_stats.hash_rate_perpetual = (
                curr_stats.rounds_total * update_interval
            ) / new_time_spent_total
        curr_stats.time_spent_total = new_time_spent_total

        # Update the logger
        logger.update(curr_stats, verbose=log_verbose)

    # exited while, solution contains the nonce or wallet is registered
    stop_event.set()  # stop all other processes
    logger.stop()

    # terminate and wait for all solvers to exit
    terminate_workers_and_wait_for_exit(solvers)

    return solution
+
+
async def _solve_for_difficulty_fast_cuda(
    subtensor: "AsyncSubtensor",
    wallet: "Wallet",
    netuid: int,
    output_in_place: bool = True,
    update_interval: int = 50_000,
    tpb: int = 512,
    dev_id: Union[list[int], int] = 0,
    n_samples: int = 10,
    alpha_: float = 0.80,
    log_verbose: bool = False,
) -> Optional["POWSolution"]:
    """
    Solves the registration PoW quickly using one or more CUDA devices.

    Args:
        subtensor (bittensor.core.async_subtensor.AsyncSubtensor): Subtensor client for block data.
        wallet (bittensor_wallet.Wallet): The wallet to register.
        netuid (int): The netuid of the subnet to register to.
        output_in_place (bool): If True, prints the output in place; otherwise on new lines.
        update_interval (int): Number of nonces to try before checking for more blocks.
        tpb (int): Threads per block; CUDA parameter that should match the GPU capability.
        dev_id (Union[list[int], int]): CUDA device id(s) to run the registration on.
        n_samples (int): Number of hash-rate samples kept for the EWMA.
        alpha_ (float): EWMA decay factor for the hash-rate calculation.
        log_verbose (bool): If True, logs more verbose registration metrics.

    Note:
        The hash rate is reported as an exponentially weighted moving average
        to make the measure more robust.
    """
    # Normalize dev_id into a non-empty list of device ids.
    if isinstance(dev_id, int):
        device_ids = [dev_id]
    elif dev_id is None:
        device_ids = [0]
    else:
        device_ids = dev_id

    # Preserved from the original implementation; _block_solver re-derives
    # the real worker count from the number of CUDA devices, so this value
    # is effectively unused on the CUDA path.
    num_processes = min(1, get_cpu_count())

    interval = 50_000 if update_interval is None else update_interval

    if not torch.cuda.is_available():
        raise Exception("CUDA not available")

    # Force the spawn start method so CUDA doesn't complain.
    with UsingSpawnStartMethod(force=True):
        curr_block, curr_block_num, curr_diff = CUDASolver.create_shared_memory()

        return await _block_solver(
            subtensor=subtensor,
            wallet=wallet,
            num_processes=num_processes,
            netuid=netuid,
            dev_id=device_ids,
            tpb=tpb,
            update_interval=interval,
            curr_block=curr_block,
            curr_block_num=curr_block_num,
            curr_diff=curr_diff,
            n_samples=n_samples,
            alpha_=alpha_,
            output_in_place=output_in_place,
            log_verbose=log_verbose,
            cuda=True,
        )
+
+
async def _solve_for_difficulty_fast(
    subtensor: "AsyncSubtensor",
    wallet: "Wallet",
    netuid: int,
    output_in_place: bool = True,
    num_processes: Optional[int] = None,
    update_interval: Optional[int] = None,
    n_samples: int = 10,
    alpha_: float = 0.80,
    log_verbose: bool = False,
) -> Optional["POWSolution"]:
    """
    Solves the registration PoW on the CPU using multiprocessing.

    Args:
        subtensor (bittensor.core.async_subtensor.AsyncSubtensor): Subtensor client for block data.
        wallet (bittensor_wallet.Wallet): Wallet to use for registration.
        netuid (int): The netuid of the subnet to register to.
        output_in_place (bool): If True, prints the status in place; otherwise on a new line.
        num_processes (Optional[int]): Number of worker processes; defaults to the allowed CPU count.
        update_interval (Optional[int]): Nonces to solve before refreshing block information; defaults to 50,000.
        n_samples (int): Number of hash-rate samples kept for the EWMA.
        alpha_ (float): EWMA decay factor for the hash-rate calculation.
        log_verbose (bool): If True, logs more verbose registration metrics.

    Notes:
        The hash rate is reported as an exponentially weighted moving average
        to make the measure more robust. The update interval can be lowered to
        do smaller blocks of work while still refreshing block information,
        trading a little speed for more transparency.
    """
    # Fall back to the permitted process count when none was requested.
    worker_count = num_processes if num_processes else min(1, get_cpu_count())
    interval = 50_000 if update_interval is None else update_interval

    curr_block, curr_block_num, curr_diff = Solver.create_shared_memory()

    return await _block_solver(
        subtensor=subtensor,
        wallet=wallet,
        num_processes=worker_count,
        netuid=netuid,
        dev_id=None,
        tpb=None,
        update_interval=interval,
        curr_block=curr_block,
        curr_block_num=curr_block_num,
        curr_diff=curr_diff,
        n_samples=n_samples,
        alpha_=alpha_,
        output_in_place=output_in_place,
        log_verbose=log_verbose,
        cuda=False,
    )
+
+
async def create_pow_async(
    subtensor: "AsyncSubtensor",
    wallet: "Wallet",
    netuid: int,
    output_in_place: bool = True,
    cuda: bool = False,
    dev_id: Union[list[int], int] = 0,
    tpb: int = 256,
    num_processes: Optional[int] = None,
    update_interval: Optional[int] = None,
    log_verbose: bool = False,
) -> Optional["POWSolution"]:
    """
    Creates a proof of work for the given subtensor and wallet.

    Args:
        subtensor (bittensor.core.async_subtensor.AsyncSubtensor): The subtensor object to use to get the block number, difficulty, and block hash.
        wallet (bittensor_wallet.Wallet): The wallet to create a proof of work for.
        netuid (int): The netuid for the subnet to create a proof of work for.
        output_in_place (bool): If true, prints the progress of the proof of work to the console in-place. Meaning the progress is printed on the same lines.
        cuda (bool): If true, uses CUDA to solve the proof of work.
        dev_id (Union[list[int], int]): The CUDA device id(s) to use. If cuda is true and dev_id is a list, then multiple CUDA devices will be used to solve the proof of work.
        tpb (int): The number of threads per block to use when solving the proof of work. Should be a multiple of 32.
        num_processes (Optional[int]): The number of processes to use when solving the proof of work. If None, then the number of processes is equal to the number of CPU cores.
        update_interval (Optional[int]): The number of nonces to run before checking for a new block.
        log_verbose (bool): If true, prints the progress of the proof of work more verbosely.

    Returns:
        The proof of work solution, or ``None`` if the wallet is already
        registered or there is a different error. (The return annotation is
        Optional accordingly: both solver paths may return ``None``.)

    Raises:
        ValueError: If the subnet does not exist.
    """
    # Root network (netuid -1) has no subnet-existence check.
    if netuid != -1:
        if not await subtensor.subnet_exists(netuid=netuid):
            raise ValueError(f"Subnet {netuid} does not exist")
    solution: Optional["POWSolution"]
    if cuda:
        solution = await _solve_for_difficulty_fast_cuda(
            subtensor=subtensor,
            wallet=wallet,
            netuid=netuid,
            output_in_place=output_in_place,
            dev_id=dev_id,
            tpb=tpb,
            update_interval=update_interval,
            log_verbose=log_verbose,
        )
    else:
        solution = await _solve_for_difficulty_fast(
            subtensor=subtensor,
            wallet=wallet,
            netuid=netuid,
            output_in_place=output_in_place,
            num_processes=num_processes,
            update_interval=update_interval,
            log_verbose=log_verbose,
        )

    return solution
diff --git a/bittensor/utils/registration/pow.py b/bittensor/utils/registration/pow.py
new file mode 100644
index 0000000000..c96295b0cd
--- /dev/null
+++ b/bittensor/utils/registration/pow.py
@@ -0,0 +1,1138 @@
+"""This module provides utilities for solving Proof-of-Work (PoW) challenges in Bittensor network."""
+
+import binascii
+from dataclasses import dataclass
+import functools
+import hashlib
+import math
+import multiprocessing as mp
+import os
+import random
+import subprocess
+import time
+from datetime import timedelta
+from multiprocessing.queues import Queue as QueueType
+from queue import Empty, Full
+from typing import Any, Callable, Optional, Union, TYPE_CHECKING
+
+import numpy
+from Crypto.Hash import keccak
+from rich import console as rich_console, status as rich_status
+from rich.console import Console
+
+from bittensor.utils.btlogging import logging
+from bittensor.utils.formatting import get_human_readable, millify
+from bittensor.utils.registration.register_cuda import solve_cuda
+
+
def use_torch() -> bool:
    """Return True when the ``USE_TORCH`` environment variable forces torch over numpy."""
    return os.getenv("USE_TORCH") == "1"
+
+
def legacy_torch_api_compat(func):
    """
    Convert function operating on numpy Input&Output to legacy torch Input&Output API if `use_torch()` is True.

    Args:
        func (function): Function with numpy Input/Output to be decorated.

    Returns:
        decorated (function): Decorated function.
    """

    @functools.wraps(func)
    def decorated(*args, **kwargs):
        if use_torch():
            # Torch tensors in the arguments are converted to numpy arrays
            # before calling the wrapped (numpy-native) function.
            args = [
                item.cpu().numpy() if isinstance(item, torch.Tensor) else item
                for item in args
            ]
            kwargs = {
                name: (val.cpu().numpy() if isinstance(val, torch.Tensor) else val)
                for name, val in kwargs.items()
            }
        result = func(*args, **kwargs)
        # Mirror the input conversion on the way out: numpy -> torch tensor.
        if use_torch() and isinstance(result, numpy.ndarray):
            result = torch.from_numpy(result)
        return result

    return decorated
+
+
@functools.cache
def _get_real_torch():
    """Import and return the real torch module, or None if torch is not installed (cached)."""
    try:
        import torch as _torch_module
    except ImportError:
        return None
    return _torch_module
+
+
def log_no_torch_error():
    """Log an error telling the user how to install torch for bittensor."""
    message = (
        "This command requires torch. You can install torch for bittensor"
        ' with `pip install bittensor[torch]` or `pip install ".[torch]"`'
        " if installing from source, and then run the command with USE_TORCH=1 {command}"
    )
    logging.error(message)
+
+
class LazyLoadedTorch:
    """A lazy-loading proxy for the torch module."""

    def __bool__(self):
        # Truthy only when torch can actually be imported.
        return _get_real_torch() is not None

    def __getattr__(self, name):
        real_torch = _get_real_torch()
        if real_torch is None:
            # Tell the user how to install torch, then fail loudly.
            log_no_torch_error()
            raise ImportError("torch not installed")
        return getattr(real_torch, name)
+
+
+if TYPE_CHECKING:
+ import torch
+ from bittensor.core.subtensor import Subtensor
+ from bittensor.core.async_subtensor import AsyncSubtensor
+ from bittensor_wallet import Wallet
+else:
+ torch = LazyLoadedTorch()
+
+
def _hex_bytes_to_u8_list(hex_bytes: bytes) -> list[int]:
    """
    Convert an ASCII-hex byte string into a list of unsigned 8-bit integers.

    Args:
        hex_bytes (bytes): A byte string of hexadecimal characters (e.g. ``b"0aff"``).

    Returns:
        list[int]: One integer in ``0..255`` per pair of hex characters.
    """
    return [int(hex_bytes[i : i + 2], 16) for i in range(0, len(hex_bytes), 2)]
+
+
def _create_seal_hash(block_and_hotkey_hash_bytes: bytes, nonce: int) -> bytes:
    """
    Create a cryptographic seal hash from the given block and hotkey hash bytes and nonce.

    The nonce is serialized as 8 little-endian bytes and hex-encoded, then concatenated
    with the first 64 hex characters of the block-and-hotkey hash. That pre-seal is
    hashed with SHA-256 and the digest is hashed again with Keccak-256.

    Args:
        block_and_hotkey_hash_bytes (bytes): The combined hash bytes of the block and hotkey.
        nonce (int): The nonce value used for hashing.

    Returns:
        The resulting seal hash.
    """
    hex_nonce = binascii.hexlify(nonce.to_bytes(8, "little"))
    hex_block_hash = binascii.hexlify(block_and_hotkey_hash_bytes)[:64]
    pre_seal = hex_nonce + hex_block_hash
    sha_digest = hashlib.sha256(bytearray(_hex_bytes_to_u8_list(pre_seal))).digest()
    keccak_hasher = keccak.new(digest_bits=256)
    return keccak_hasher.update(sha_digest).digest()
+
+
def _seal_meets_difficulty(seal: bytes, difficulty: int, limit: int) -> bool:
    """Return True when the seal (interpreted big-endian) times the difficulty stays below the limit."""
    seal_value = int.from_bytes(seal, "big")
    return seal_value * difficulty < limit
+
+
@dataclass
class POWSolution:
    """A solution to the registration PoW problem."""

    nonce: int  # the winning nonce
    block_number: int  # block the PoW was solved for
    difficulty: int  # difficulty the seal was checked against
    seal: bytes  # the resulting seal hash

    def is_stale(self, subtensor: "Subtensor") -> bool:
        """
        Synchronous implementation. Returns True if the POW is stale.

        This means the block the POW is solved for is within 3 blocks of the current block.
        """
        return subtensor.get_current_block() - 3 > self.block_number

    async def is_stale_async(self, subtensor: "AsyncSubtensor") -> bool:
        """
        Asynchronous implementation. Returns True if the POW is stale.

        This means the block the POW is solved for is within 3 blocks of the current block.
        """
        latest_block = await subtensor.substrate.get_block_number(None)
        return latest_block - 3 > self.block_number
+
+
class UsingSpawnStartMethod:
    """Context manager that temporarily forces the multiprocessing 'spawn' start method."""

    def __init__(self, force: bool = False):
        self._old_start_method = None
        self._force = force

    def __enter__(self):
        previous = mp.get_start_method(allow_none=True)
        # If no method was set yet, restore to "spawn" on exit (the default here).
        self._old_start_method = "spawn" if previous is None else previous
        mp.set_start_method("spawn", force=self._force)

    def __exit__(self, *args):
        # Restore whatever start method was active before entering.
        mp.set_start_method(self._old_start_method, force=True)
+
+
class _SolverBase(mp.Process):
    """
    Abstract base for a process that solves the registration PoW problem.

    Args:
        proc_num (int): The number of the process being created.
        num_proc (int): The total number of processes running.
        update_interval (int): The number of nonces to try to solve before checking for a new block.
        finished_queue (multiprocessing.Queue): The queue to put the process number when a process finishes each update_interval. Used for calculating the average time per update_interval across all processes.
        solution_queue (multiprocessing.Queue): The queue to put the solution the process has found during the pow solve.
        stopEvent (multiprocessing.Event): The event to set by the main process when all the solver processes should stop. The solver process will check for the event after each update_interval. The solver process will stop when the event is set. Used to stop the solver processes when a solution is found.
        curr_block (multiprocessing.Array): The array containing this process's current block hash. The main process will set the array to the new block hash when a new block is finalized in the network. The solver process will get the new block hash from this array when newBlockEvent is set.
        curr_block_num (multiprocessing.Value): The value containing this process's current block number. The main process will set the value to the new block number when a new block is finalized in the network. The solver process will get the new block number from this value when newBlockEvent is set.
        curr_diff (multiprocessing.Array): The array containing this process's current difficulty. The main process will set the array to the new difficulty when a new block is finalized in the network. The solver process will get the new difficulty from this array when newBlockEvent is set.
        check_block (multiprocessing.Lock): The lock to prevent this process from getting the new block data while the main process is updating the data.
        limit (int): The limit of the pow solve for a valid solution.
    """

    proc_num: int
    num_proc: int
    update_interval: int
    finished_queue: "mp.Queue"
    solution_queue: "mp.Queue"
    newBlockEvent: "mp.Event"
    stopEvent: "mp.Event"
    hotkey_bytes: bytes
    curr_block: "mp.Array"
    curr_block_num: "mp.Value"
    curr_diff: "mp.Array"
    check_block: "mp.Lock"
    limit: int

    def __init__(
        self,
        proc_num,
        num_proc,
        update_interval,
        finished_queue,
        solution_queue,
        stopEvent,
        curr_block,
        curr_block_num,
        curr_diff,
        check_block,
        limit,
    ):
        super().__init__(daemon=True)
        # Identity and pacing of this worker.
        self.proc_num = proc_num
        self.num_proc = num_proc
        self.update_interval = update_interval
        # Channels back to the parent process.
        self.finished_queue = finished_queue
        self.solution_queue = solution_queue
        # Shared block state, guarded by check_block.
        self.curr_block = curr_block
        self.curr_block_num = curr_block_num
        self.curr_diff = curr_diff
        self.check_block = check_block
        # Control events and solve limit.
        self.stopEvent = stopEvent
        self.limit = limit
        self.newBlockEvent = mp.Event()
        self.newBlockEvent.clear()

    def run(self):
        raise NotImplementedError("_SolverBase is an abstract class")

    @staticmethod
    def create_shared_memory() -> tuple["mp.Array", "mp.Value", "mp.Array"]:
        """Creates shared memory for the solver processes to use."""
        shared_block = mp.Array("h", 32, lock=True)  # byte array
        shared_block_num = mp.Value("i", 0, lock=True)  # int
        shared_diff = mp.Array("Q", [0, 0], lock=True)  # [high, low]
        return shared_block, shared_block_num, shared_diff
+
+
class Solver(_SolverBase):
    """CPU solver process: repeatedly scans random nonce ranges for a seal that meets the current difficulty."""

    def run(self):
        """Main worker loop; exits when the parent sets ``stopEvent``."""
        block_number: int
        block_and_hotkey_hash_bytes: bytes
        block_difficulty: int
        nonce_limit = int(math.pow(2, 64)) - 1  # U64 max

        # Start at random nonce
        nonce_start = random.randint(0, nonce_limit)
        nonce_end = nonce_start + self.update_interval
        while not self.stopEvent.is_set():
            # NOTE(review): the parent sets newBlockEvent before starting the
            # process, so the block variables are populated on the first pass.
            if self.newBlockEvent.is_set():
                # Copy the shared block state under the lock so the parent
                # cannot update it mid-read.
                with self.check_block:
                    block_number = self.curr_block_num.value
                    block_and_hotkey_hash_bytes = bytes(self.curr_block)
                    block_difficulty = _registration_diff_unpack(self.curr_diff)

                self.newBlockEvent.clear()

            # Do a block of nonces
            solution = _solve_for_nonce_block(
                nonce_start,
                nonce_end,
                block_and_hotkey_hash_bytes,
                block_difficulty,
                self.limit,
                block_number,
            )
            if solution is not None:
                self.solution_queue.put(solution)

            try:
                # Send time
                self.finished_queue.put_nowait(self.proc_num)
            except Full:
                pass

            # Pick a fresh random starting nonce for the next batch.
            nonce_start = random.randint(0, nonce_limit)
            nonce_start = nonce_start % nonce_limit
            nonce_end = nonce_start + self.update_interval
+
+
class CUDASolver(_SolverBase):
    """CUDA-backed solver process; each instance drives one GPU device."""

    dev_id: int  # CUDA device index this solver runs on
    tpb: int  # threads per block for the CUDA kernel

    def __init__(
        self,
        proc_num,
        num_proc,
        update_interval,
        finished_queue,
        solution_queue,
        stopEvent,
        curr_block,
        curr_block_num,
        curr_diff,
        check_block,
        limit,
        dev_id: int,
        tpb: int,
    ):
        """Initialize the base solver state plus the CUDA device id and threads-per-block."""
        super().__init__(
            proc_num,
            num_proc,
            update_interval,
            finished_queue,
            solution_queue,
            stopEvent,
            curr_block,
            curr_block_num,
            curr_diff,
            check_block,
            limit,
        )
        self.dev_id = dev_id
        self.tpb = tpb

    def run(self):
        """Main worker loop; exits when the parent sets ``stopEvent``."""
        block_number: int = 0  # dummy value
        block_and_hotkey_hash_bytes: bytes = b"0" * 32  # dummy value
        block_difficulty: int = int(math.pow(2, 64)) - 1  # dummy value
        nonce_limit = int(math.pow(2, 64)) - 1  # U64MAX

        # Start at random nonce
        nonce_start = random.randint(0, nonce_limit)
        while not self.stopEvent.is_set():
            if self.newBlockEvent.is_set():
                # Copy the shared block state under the lock so the parent
                # cannot update it mid-read.
                with self.check_block:
                    block_number = self.curr_block_num.value
                    block_and_hotkey_hash_bytes = bytes(self.curr_block)
                    block_difficulty = _registration_diff_unpack(self.curr_diff)

                self.newBlockEvent.clear()

            # Do a block of nonces
            solution = _solve_for_nonce_block_cuda(
                nonce_start,
                self.update_interval,
                block_and_hotkey_hash_bytes,
                block_difficulty,
                self.limit,
                block_number,
                self.dev_id,
                self.tpb,
            )
            if solution is not None:
                self.solution_queue.put(solution)

            try:
                # Signal that a nonce_block was finished using queue
                # send our proc_num
                self.finished_queue.put(self.proc_num)
            except Full:
                pass

            # increase nonce by number of nonces processed
            nonce_start += self.update_interval * self.tpb
            nonce_start = nonce_start % nonce_limit
+
+
def _solve_for_nonce_block_cuda(
    nonce_start: int,
    update_interval: int,
    block_and_hotkey_hash_bytes: bytes,
    difficulty: int,
    limit: int,
    block_number: int,
    dev_id: int,
    tpb: int,
) -> Optional["POWSolution"]:
    """Tries to solve the POW on a CUDA device for a block of nonces (nonce_start, nonce_start + update_interval * tpb"""
    found_nonce, seal = solve_cuda(
        nonce_start,
        update_interval,
        tpb,
        block_and_hotkey_hash_bytes,
        difficulty,
        limit,
        dev_id,
    )

    # solve_cuda signals "no solution in this range" with -1.
    if found_nonce == -1:
        return None

    return POWSolution(found_nonce, block_number, difficulty, seal)
+
+
def _solve_for_nonce_block(
    nonce_start: int,
    nonce_end: int,
    block_and_hotkey_hash_bytes: bytes,
    difficulty: int,
    limit: int,
    block_number: int,
) -> Optional["POWSolution"]:
    """Tries to solve the POW for a block of nonces (nonce_start, nonce_end)"""
    for candidate in range(nonce_start, nonce_end):
        seal = _create_seal_hash(block_and_hotkey_hash_bytes, candidate)
        # First nonce whose seal meets the difficulty wins.
        if _seal_meets_difficulty(seal, difficulty, limit):
            return POWSolution(candidate, block_number, difficulty, seal)
    return None
+
+
def _registration_diff_unpack(packed_diff: "mp.Array") -> int:
    """Combine the packed [high, low] 32-bit halves into one 64-bit difficulty. Little endian."""
    high, low = packed_diff[0], packed_diff[1]
    return int((high << 32) | low)
+
+
def _registration_diff_pack(diff: int, packed_diff: "mp.Array"):
    """Split a 64-bit difficulty into [high, low] 32-bit halves stored in packed_diff. Little endian."""
    high, low = divmod(diff, 1 << 32)
    packed_diff[0] = high
    packed_diff[1] = low
+
+
def _hash_block_with_hotkey(block_bytes: bytes, hotkey_bytes: bytes) -> bytes:
    """Keccak-256 hash (32 bytes) of the block bytes concatenated with the hotkey bytes."""
    hasher = keccak.new(digest_bits=256)
    hasher.update(bytearray(block_bytes + hotkey_bytes))
    return hasher.digest()
+
+
def update_curr_block(
    curr_diff: "mp.Array",
    curr_block: "mp.Array",
    curr_block_num: "mp.Value",
    block_number: int,
    block_bytes: bytes,
    diff: int,
    hotkey_bytes: bytes,
    lock: "mp.Lock",
):
    """
    Update the shared block state with new block information and difficulty.

    Holding the lock, this sets the current block number, stores the Keccak-256
    hash of the block bytes combined with the hotkey, and packs the difficulty
    into the shared difficulty array.

    Arguments:
        curr_diff: Shared array to store the current difficulty.
        curr_block: Shared array to store the current block data.
        curr_block_num: Shared value to store the current block number.
        block_number: The block number to set as the current block number.
        block_bytes: The block data bytes to be hashed with the hotkey.
        diff: The difficulty value to be packed into the current difficulty array.
        hotkey_bytes: The hotkey bytes used for hashing the block.
        lock: A lock to ensure thread-safe updates.
    """
    with lock:
        curr_block_num.value = block_number
        # Hash the block together with the hotkey; the digest is exactly 32 bytes.
        combined_hash = _hash_block_with_hotkey(block_bytes, hotkey_bytes)
        for idx, byte_value in enumerate(combined_hash):
            curr_block[idx] = byte_value
        _registration_diff_pack(diff, curr_diff)
+
+
def get_cpu_count() -> int:
    """
    Return the number of CPUs available to this process.

    Uses ``os.sched_getaffinity`` where available (Linux) so CPU affinity masks
    are respected; falls back to ``os.cpu_count()`` elsewhere (e.g. macOS).
    Returns at least 1 — ``os.cpu_count()`` may return ``None`` when the count
    cannot be determined.
    """
    try:
        return len(os.sched_getaffinity(0))
    except AttributeError:
        # macOS does not have sched_getaffinity
        return os.cpu_count() or 1
+
+
@dataclass
class RegistrationStatistics:
    """Statistics for a registration."""

    time_spent_total: float  # seconds elapsed since solving started
    rounds_total: int  # total update_interval rounds completed by all solvers
    time_average: float  # average seconds per round
    time_spent: float  # seconds spent on the most recent round(s)
    hash_rate_perpetual: float  # overall hashes/second since start
    hash_rate: float  # EWMA hashes/second
    difficulty: int  # current registration difficulty
    block_number: int  # current block number
    block_hash: str  # current block hash (hex string)
+
+
class RegistrationStatisticsLogger:
    """Logs statistics for a registration."""

    # rich spinner used for in-place output; None when logging line-by-line
    status: Optional[rich_status.Status]

    def __init__(
        self,
        console: Optional[rich_console.Console] = None,
        output_in_place: bool = True,
    ) -> None:
        """Create a logger; with ``output_in_place`` a rich status spinner is used, otherwise plain console logs."""
        if console is None:
            console = Console()

        self.console = console

        if output_in_place:
            self.status = self.console.status("Solving")
        else:
            self.status = None

    def start(self) -> None:
        """Start the in-place status display (no-op when logging line-by-line)."""
        if self.status is not None:
            self.status.start()

    def stop(self) -> None:
        """Stop the in-place status display (no-op when logging line-by-line)."""
        if self.status is not None:
            self.status.stop()

    @classmethod
    def get_status_message(
        cls, stats: "RegistrationStatistics", verbose: bool = False
    ) -> str:
        """Generates the status message based on registration statistics."""
        message = (
            "Solving\n"
            + f"Time Spent (total): [bold white]{timedelta(seconds=stats.time_spent_total)}[/bold white]\n"
            + (
                f"Time Spent This Round: {timedelta(seconds=stats.time_spent)}\n"
                + f"Time Spent Average: {timedelta(seconds=stats.time_average)}\n"
                if verbose
                else ""
            )
            + f"Registration Difficulty: [bold white]{millify(stats.difficulty)}[/bold white]\n"
            + f"Iters (Inst/Perp): [bold white]{get_human_readable(stats.hash_rate, 'H')}/s / "
            + f"{get_human_readable(stats.hash_rate_perpetual, 'H')}/s[/bold white]\n"
            + f"Block Number: [bold white]{stats.block_number}[/bold white]\n"
            + f"Block Hash: [bold white]{stats.block_hash.encode('utf-8')}[/bold white]\n"
        )
        return message

    def update(self, stats: "RegistrationStatistics", verbose: bool = False) -> None:
        """Render the latest statistics, either updating the spinner in place or logging a new line."""
        if self.status is not None:
            self.status.update(self.get_status_message(stats, verbose=verbose))
        else:
            self.console.log(self.get_status_message(stats, verbose=verbose))
+
+
def _solve_for_difficulty_fast(
    subtensor: "Subtensor",
    wallet: "Wallet",
    netuid: int,
    output_in_place: bool = True,
    num_processes: Optional[int] = None,
    update_interval: Optional[int] = None,
    n_samples: int = 10,
    alpha_: float = 0.80,
    log_verbose: bool = False,
) -> Optional[POWSolution]:
    """
    Solves the POW for registration using multiprocessing.

    Args:
        subtensor (bittensor.core.subtensor.Subtensor): Subtensor instance to connect to for block information and to submit.
        wallet (bittensor_wallet.Wallet): wallet to use for registration.
        netuid (int): The netuid of the subnet to register to.
        output_in_place (bool): If true, prints the status in place. Otherwise, prints the status on a new line.
        num_processes (int): Number of processes to use.
        update_interval (int): Number of nonces to solve before updating block information.
        n_samples (int): The number of samples of the hash_rate to keep for the EWMA.
        alpha_ (float): The alpha for the EWMA for the hash_rate calculation.
        log_verbose (bool): If true, prints more verbose logging of the registration metrics.

    Returns:
        The solved POWSolution, or ``None`` if the hotkey became registered before a solution was found.

    Note: The hash rate is calculated as an exponentially weighted moving average in order to make the measure more robust.
    Note: We can also modify the update interval to do smaller blocks of work, while still updating the block information after a different number of nonces, to increase the transparency of the process while still keeping the speed.
    """
    if num_processes is None:
        # Use all CPUs available to this process (at least one).
        # BUGFIX: this was previously `min(1, get_cpu_count())`, which always
        # forced a single process, contradicting the documented behavior.
        num_processes = max(1, get_cpu_count())

    if update_interval is None:
        update_interval = 50_000

    limit = int(math.pow(2, 256)) - 1  # U256 max: upper bound for seal * difficulty

    curr_block, curr_block_num, curr_diff = Solver.create_shared_memory()

    # Establish communication queues
    # See the Solver class for more information on the queues.
    stopEvent = mp.Event()
    stopEvent.clear()

    solution_queue = mp.Queue()
    finished_queues = [mp.Queue() for _ in range(num_processes)]
    check_block = mp.Lock()

    # netuid -1 is the faucet path, which keys the PoW on the coldkey instead.
    hotkey_bytes = (
        wallet.coldkeypub.public_key if netuid == -1 else wallet.hotkey.public_key
    )
    # Start consumers
    solvers = [
        Solver(
            i,
            num_processes,
            update_interval,
            finished_queues[i],
            solution_queue,
            stopEvent,
            curr_block,
            curr_block_num,
            curr_diff,
            check_block,
            limit,
        )
        for i in range(num_processes)
    ]

    # Get first block
    block_number, difficulty, block_hash = _get_block_with_retry(
        subtensor=subtensor, netuid=netuid
    )

    block_bytes = bytes.fromhex(block_hash[2:])
    old_block_number = block_number
    # Set to current block
    update_curr_block(
        curr_diff,
        curr_block,
        curr_block_num,
        block_number,
        block_bytes,
        difficulty,
        hotkey_bytes,
        check_block,
    )

    # Set new block events for each solver to start at the initial block
    for worker in solvers:
        worker.newBlockEvent.set()

    for worker in solvers:
        worker.start()  # start the solver processes

    start_time = time.time()  # time that the registration started
    time_last = start_time  # time that the last work blocks completed

    curr_stats = RegistrationStatistics(
        time_spent_total=0.0,
        time_average=0.0,
        rounds_total=0,
        time_spent=0.0,
        hash_rate_perpetual=0.0,
        hash_rate=0.0,
        difficulty=difficulty,
        block_number=block_number,
        block_hash=block_hash,
    )

    start_time_perpetual = time.time()

    logger = RegistrationStatisticsLogger(output_in_place=output_in_place)
    logger.start()

    solution = None

    hash_rates = [0] * n_samples  # The last n true hash_rates
    weights = [alpha_**i for i in range(n_samples)]  # weights decay by alpha

    # Loop until a solver finds a solution or the hotkey gets registered
    # (e.g. by a concurrent registration elsewhere).
    while netuid == -1 or not subtensor.is_hotkey_registered(
        netuid=netuid, hotkey_ss58=wallet.hotkey.ss58_address
    ):
        # Wait until a solver finds a solution
        try:
            solution = solution_queue.get(block=True, timeout=0.25)
            if solution is not None:
                break
        except Empty:
            # No solution found, try again
            pass

        # check for new block
        old_block_number = _check_for_newest_block_and_update(
            subtensor=subtensor,
            netuid=netuid,
            hotkey_bytes=hotkey_bytes,
            old_block_number=old_block_number,
            curr_diff=curr_diff,
            curr_block=curr_block,
            curr_block_num=curr_block_num,
            curr_stats=curr_stats,
            update_curr_block_=update_curr_block,
            check_block=check_block,
            solvers=solvers,
        )

        # Count how many solvers finished a round since the last check.
        num_time = 0
        for finished_queue in finished_queues:
            try:
                finished_queue.get(timeout=0.1)
                num_time += 1

            except Empty:
                continue

        time_now = time.time()  # get current time
        time_since_last = time_now - time_last  # get time since last work block(s)
        if num_time > 0 and time_since_last > 0.0:
            # create EWMA of the hash_rate to make measure more robust

            hash_rate_ = (num_time * update_interval) / time_since_last
            hash_rates.append(hash_rate_)
            hash_rates.pop(0)  # remove the 0th data point
            curr_stats.hash_rate = sum(
                [hash_rates[i] * weights[i] for i in range(n_samples)]
            ) / (sum(weights))

            # update time last to now
            time_last = time_now

            curr_stats.time_average = (
                curr_stats.time_average * curr_stats.rounds_total
                + curr_stats.time_spent
            ) / (curr_stats.rounds_total + num_time)
            curr_stats.rounds_total += num_time

        # Update stats
        curr_stats.time_spent = time_since_last
        new_time_spent_total = time_now - start_time_perpetual
        curr_stats.hash_rate_perpetual = (
            curr_stats.rounds_total * update_interval
        ) / new_time_spent_total
        curr_stats.time_spent_total = new_time_spent_total

        # Update the logger
        logger.update(curr_stats, verbose=log_verbose)

    # exited while, solution contains the nonce or wallet is registered
    stopEvent.set()  # stop all other processes
    logger.stop()

    # terminate and wait for all solvers to exit
    terminate_workers_and_wait_for_exit(solvers)

    return solution
+
+
def _get_block_with_retry(subtensor: "Subtensor", netuid: int) -> tuple[int, int, str]:
    """
    Gets the current block number, difficulty, and block hash from the substrate node.

    NOTE(review): despite the name, no retry logic is implemented here — any
    failure raises immediately.

    Args:
        subtensor (bittensor.core.subtensor.Subtensor): The subtensor object to use to get the block number, difficulty, and block hash.
        netuid (int): The netuid of the network to get the block number, difficulty, and block hash from.

    Returns:
        tuple[int, int, str]
            block_number (int): The current block number.
            difficulty (int): The current difficulty of the subnet (a fixed
                1_000_000 for the faucet pseudo-netuid ``-1``).
            block_hash (str): The current block hash (``0x``-prefixed hex string).

    Raises:
        Exception: If the block hash is None.
        ValueError: If the difficulty is None.
    """
    block_number = subtensor.get_current_block()
    difficulty = 1_000_000 if netuid == -1 else subtensor.difficulty(netuid=netuid)
    block_hash = subtensor.get_block_hash(block_number)
    if block_hash is None:
        raise Exception(
            "Network error. Could not connect to substrate to get block hash"
        )
    if difficulty is None:
        raise ValueError("Chain error. Difficulty is None")
    return block_number, difficulty, block_hash
+
+
def _check_for_newest_block_and_update(
    subtensor: "Subtensor",
    netuid: int,
    old_block_number: int,
    hotkey_bytes: bytes,
    curr_diff: "mp.Array",
    curr_block: "mp.Array",
    curr_block_num: "mp.Value",
    update_curr_block_: "Callable",
    check_block: "mp.Lock",
    solvers: Union[list["Solver"], list["CUDASolver"]],
    curr_stats: "RegistrationStatistics",
) -> int:
    """
    Checks for a new block and updates the current block information if a new block is found.

    Args:
        subtensor (bittensor.core.subtensor.Subtensor): The subtensor object to use for getting the current block.
        netuid (int): The netuid to use for retrieving the difficulty.
        old_block_number (int): The old block number to check against.
        hotkey_bytes (bytes): The bytes of the hotkey's pubkey.
        curr_diff (multiprocessing.Array): The current difficulty as a multiprocessing array.
        curr_block (multiprocessing.Array): Where the current block is stored as a multiprocessing array.
        curr_block_num (multiprocessing.Value): Where the current block number is stored as a multiprocessing value.
        update_curr_block_ (typing.Callable): A function that updates the current block.
        check_block (multiprocessing.Lock): A mp lock that is used to check for a new block.
        solvers (list[bittensor.utils.registration.Solver]): A list of solvers to update the current block for.
        curr_stats (bittensor.utils.registration.RegistrationStatistics): The current registration statistics to update.

    Returns:
        (int) The current block number.
    """
    block_number = subtensor.get_current_block()
    if block_number != old_block_number:
        # A new block arrived: refresh difficulty and hash, push the new state
        # into shared memory, and notify every solver.
        old_block_number = block_number
        # update block information
        block_number, difficulty, block_hash = _get_block_with_retry(
            subtensor=subtensor, netuid=netuid
        )
        # Strip the "0x" prefix before decoding the hash to raw bytes.
        block_bytes = bytes.fromhex(block_hash[2:])

        update_curr_block_(
            curr_diff,
            curr_block,
            curr_block_num,
            block_number,
            block_bytes,
            difficulty,
            hotkey_bytes,
            check_block,
        )
        # Set new block events for each solver

        for worker in solvers:
            worker.newBlockEvent.set()

        # update stats
        curr_stats.block_number = block_number
        curr_stats.block_hash = block_hash
        curr_stats.difficulty = difficulty

    return old_block_number
+
+
def _solve_for_difficulty_fast_cuda(
    subtensor: "Subtensor",
    wallet: "Wallet",
    netuid: int,
    output_in_place: bool = True,
    update_interval: int = 50_000,
    tpb: int = 512,
    dev_id: Union[list[int], int] = 0,
    n_samples: int = 10,
    alpha_: float = 0.80,
    log_verbose: bool = False,
) -> Optional["POWSolution"]:
    """
    Solves the registration fast using CUDA.

    Args:
        subtensor (bittensor.core.subtensor.Subtensor): The subtensor node to grab blocks.
        wallet (bittensor_wallet.Wallet): The wallet to register.
        netuid (int): The netuid of the subnet to register to.
        output_in_place (bool) If true, prints the output in place, otherwise prints to new lines.
        update_interval (int): The number of nonces to try before checking for more blocks.
        tpb (int): The number of threads per block. CUDA param that should match the GPU capability
        dev_id (Union[list[int], int]): The CUDA device IDs to execute the registration on, either a single device or a list of devices.
        n_samples (int): The number of samples of the hash_rate to keep for the EWMA.
        alpha_ (float): The alpha for the EWMA for the hash_rate calculation.
        log_verbose (bool): If true, prints more verbose logging of the registration metrics.

    Returns:
        The solved POWSolution, or ``None`` if the hotkey became registered before a solution was found.

    Raises:
        Exception: If CUDA is not available on this machine.

    Note: The hash rate is calculated as an exponentially weighted moving average in order to make the measure more robust.
    """
    if isinstance(dev_id, int):
        dev_id = [dev_id]
    elif dev_id is None:
        dev_id = [0]

    if update_interval is None:
        update_interval = 50_000

    if not torch.cuda.is_available():
        raise Exception("CUDA not available")

    limit = int(math.pow(2, 256)) - 1  # U256 max: upper bound for seal * difficulty

    # Set mp start to use spawn so CUDA doesn't complain
    with UsingSpawnStartMethod(force=True):
        curr_block, curr_block_num, curr_diff = CUDASolver.create_shared_memory()

        # Create a worker per CUDA device
        num_processes = len(dev_id)

        # Establish communication queues
        stopEvent = mp.Event()
        stopEvent.clear()
        solution_queue = mp.Queue()
        finished_queues = [mp.Queue() for _ in range(num_processes)]
        check_block = mp.Lock()

        hotkey_bytes = wallet.hotkey.public_key
        # Start workers
        solvers = [
            CUDASolver(
                i,
                num_processes,
                update_interval,
                finished_queues[i],
                solution_queue,
                stopEvent,
                curr_block,
                curr_block_num,
                curr_diff,
                check_block,
                limit,
                dev_id[i],
                tpb,
            )
            for i in range(num_processes)
        ]

        # Get first block
        block_number, difficulty, block_hash = _get_block_with_retry(
            subtensor=subtensor, netuid=netuid
        )

        block_bytes = bytes.fromhex(block_hash[2:])
        old_block_number = block_number

        # Set to current block
        update_curr_block(
            curr_diff,
            curr_block,
            curr_block_num,
            block_number,
            block_bytes,
            difficulty,
            hotkey_bytes,
            check_block,
        )

        # Set new block events for each solver to start at the initial block
        for worker in solvers:
            worker.newBlockEvent.set()

        for worker in solvers:
            worker.start()  # start the solver processes

        start_time = time.time()  # time that the registration started
        time_last = start_time  # time that the last work blocks completed

        curr_stats = RegistrationStatistics(
            time_spent_total=0.0,
            time_average=0.0,
            rounds_total=0,
            time_spent=0.0,
            hash_rate_perpetual=0.0,
            hash_rate=0.0,  # EWMA hash_rate (H/s)
            difficulty=difficulty,
            block_number=block_number,
            block_hash=block_hash,
        )

        start_time_perpetual = time.time()

        logger = RegistrationStatisticsLogger(output_in_place=output_in_place)
        logger.start()

        hash_rates = [0] * n_samples  # The last n true hash_rates
        weights = [alpha_**i for i in range(n_samples)]  # weights decay by alpha

        solution = None
        # Loop until a solver finds a solution or the hotkey gets registered
        # (e.g. by a concurrent registration elsewhere).
        while netuid == -1 or not subtensor.is_hotkey_registered(
            netuid=netuid, hotkey_ss58=wallet.hotkey.ss58_address
        ):
            # Wait until a solver finds a solution
            try:
                solution = solution_queue.get(block=True, timeout=0.15)
                if solution is not None:
                    break
            except Empty:
                # No solution found, try again
                pass

            # check for new block
            old_block_number = _check_for_newest_block_and_update(
                subtensor=subtensor,
                netuid=netuid,
                hotkey_bytes=hotkey_bytes,
                curr_diff=curr_diff,
                curr_block=curr_block,
                curr_block_num=curr_block_num,
                old_block_number=old_block_number,
                curr_stats=curr_stats,
                update_curr_block_=update_curr_block,
                check_block=check_block,
                solvers=solvers,
            )

            num_time = 0
            # Get times for each solver
            for finished_queue in finished_queues:
                try:
                    finished_queue.get(timeout=0.1)
                    num_time += 1

                except Empty:
                    continue

            time_now = time.time()  # get current time
            time_since_last = time_now - time_last  # get time since last work block(s)
            if num_time > 0 and time_since_last > 0.0:
                # create EWMA of the hash_rate to make measure more robust

                hash_rate_ = (num_time * tpb * update_interval) / time_since_last
                hash_rates.append(hash_rate_)
                hash_rates.pop(0)  # remove the 0th data point
                curr_stats.hash_rate = sum(
                    [hash_rates[i] * weights[i] for i in range(n_samples)]
                ) / (sum(weights))

                # update time last to now
                time_last = time_now

                curr_stats.time_average = (
                    curr_stats.time_average * curr_stats.rounds_total
                    + curr_stats.time_spent
                ) / (curr_stats.rounds_total + num_time)
                curr_stats.rounds_total += num_time

            # Update stats
            curr_stats.time_spent = time_since_last
            new_time_spent_total = time_now - start_time_perpetual
            curr_stats.hash_rate_perpetual = (
                curr_stats.rounds_total * (tpb * update_interval)
            ) / new_time_spent_total
            curr_stats.time_spent_total = new_time_spent_total

            # Update the logger
            logger.update(curr_stats, verbose=log_verbose)

        # exited while, found_solution contains the nonce or wallet is registered

        stopEvent.set()  # stop all other processes
        logger.stop()

        # terminate and wait for all solvers to exit
        terminate_workers_and_wait_for_exit(solvers)

        return solution
+
+
def terminate_workers_and_wait_for_exit(
    workers: list[Union[mp.Process, QueueType]],
) -> None:
    """
    Shut down solver workers and their queues, waiting for a clean exit.

    Queues have their feeder thread joined; processes are given 3 seconds to
    join, are force-terminated if still alive, and are then closed.

    Args:
        workers: The processes and/or queues to shut down.
    """
    for worker in workers:
        if isinstance(worker, QueueType):
            worker.join_thread()
        else:
            # BUGFIX: mp.Process.join never raises on timeout (the previous
            # `except subprocess.TimeoutExpired` branch was dead code) — it
            # simply returns, so check is_alive() to decide whether to
            # force-terminate.
            worker.join(3.0)
            if worker.is_alive():
                worker.terminate()
                worker.join()
            try:
                worker.close()
            except ValueError:
                # close() raises ValueError if the process is somehow still running.
                worker.terminate()
+
+
def create_pow(
    subtensor: "Subtensor",
    wallet: "Wallet",
    netuid: int,
    output_in_place: bool = True,
    cuda: bool = False,
    dev_id: Union[list[int], int] = 0,
    tpb: int = 256,
    num_processes: Optional[int] = None,
    update_interval: Optional[int] = None,
    log_verbose: bool = False,
) -> Optional["POWSolution"]:
    """
    Creates a proof of work for the given subtensor and wallet.

    Args:
        subtensor (bittensor.core.subtensor.Subtensor): The subtensor to create a proof of work for.
        wallet (bittensor_wallet.Wallet): The wallet to create a proof of work for.
        netuid (int): The netuid for the subnet to create a proof of work for.
        output_in_place (bool): If true, prints the progress of the proof of work to the console in-place. Meaning the progress is printed on the same lines. Default is ``True``.
        cuda (bool): If true, uses CUDA to solve the proof of work. Default is ``False``.
        dev_id (Union[list[int], int]): The CUDA device id(s) to use. If cuda is true and dev_id is a list, then multiple CUDA devices will be used to solve the proof of work. Default is ``0``.
        tpb (int): The number of threads per block to use when solving the proof of work. Should be a multiple of 32. Default is ``256``.
        num_processes (Optional[int]): The number of processes to use when solving the proof of work. If None, then the number of processes is equal to the number of CPU cores. Default is None.
        update_interval (Optional[int]): The number of nonces to run before checking for a new block. Default is ``None``.
        log_verbose (bool): If true, prints the progress of the proof of work more verbosely. Default is ``False``.

    Returns:
        Optional[POWSolution]: The proof of work solution, or None if the wallet
        is already registered or there is a different error.
        (BUGFIX: the previous ``Optional[dict[str, Any]]`` annotation did not
        match the actual ``POWSolution`` return value.)

    Raises:
        ValueError: If the subnet does not exist.
    """
    # netuid -1 is the faucet pseudo-netuid, which has no subnet to check.
    if netuid != -1:
        if not subtensor.subnet_exists(netuid=netuid):
            raise ValueError(f"Subnet {netuid} does not exist.")

    if cuda:
        solution: Optional[POWSolution] = _solve_for_difficulty_fast_cuda(
            subtensor,
            wallet,
            netuid=netuid,
            output_in_place=output_in_place,
            dev_id=dev_id,
            tpb=tpb,
            update_interval=update_interval,
            log_verbose=log_verbose,
        )
    else:
        solution: Optional[POWSolution] = _solve_for_difficulty_fast(
            subtensor,
            wallet,
            netuid=netuid,
            output_in_place=output_in_place,
            num_processes=num_processes,
            update_interval=update_interval,
            log_verbose=log_verbose,
        )
    return solution
diff --git a/bittensor/utils/register_cuda.py b/bittensor/utils/registration/register_cuda.py
similarity index 54%
rename from bittensor/utils/register_cuda.py
rename to bittensor/utils/registration/register_cuda.py
index e0a77f19c9..756250f068 100644
--- a/bittensor/utils/register_cuda.py
+++ b/bittensor/utils/registration/register_cuda.py
@@ -1,24 +1,8 @@
-# The MIT License (MIT)
-# Copyright © 2024 Opentensor Foundation
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
-# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
-# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
-# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
-# the Software.
-#
-# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
-# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
-# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
-# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
-# DEALINGS IN THE SOFTWARE.
+"""This module provides functions for solving Proof of Work (PoW) problems using CUDA."""
import binascii
import hashlib
import io
-import math
from contextlib import redirect_stdout
from typing import Any, Union
@@ -26,6 +10,38 @@
from Crypto.Hash import keccak
+def _hex_bytes_to_u8_list(hex_bytes: bytes) -> list[int]:
+ """
+ Convert a sequence of bytes in hexadecimal format to a list of
+ unsigned 8-bit integers.
+
+ Args:
+ hex_bytes (bytes): A sequence of bytes in hexadecimal format.
+
+ Returns:
+ A list of unsigned 8-bit integers.
+
+ """
+ return [int(hex_bytes[i : i + 2], 16) for i in range(0, len(hex_bytes), 2)]
+
+
+def _create_seal_hash(block_and_hotkey_hash_hex_: bytes, nonce: int) -> bytes:
+ """Creates a seal hash from the block and hotkey hash and nonce."""
+ nonce_bytes = binascii.hexlify(nonce.to_bytes(8, "little"))
+ pre_seal = nonce_bytes + block_and_hotkey_hash_hex_
+ seal_sh256 = hashlib.sha256(bytearray(_hex_bytes_to_u8_list(pre_seal))).digest()
+ kec = keccak.new(digest_bits=256)
+ return kec.update(seal_sh256).digest()
+
+
+def _seal_meets_difficulty(seal_: bytes, difficulty: int, limit: int) -> bool:
+ """Checks if the seal meets the given difficulty."""
+ seal_number = int.from_bytes(seal_, "big")
+ product = seal_number * difficulty
+    # limit is 2**256 - 1 (the max 256-bit value), supplied by the caller
+ return product < limit
+
+
def solve_cuda(
nonce_start: "np.int64",
update_interval: "np.int64",
@@ -62,29 +78,6 @@ def solve_cuda(
upper_bytes = upper.to_bytes(32, byteorder="little", signed=False)
- def _hex_bytes_to_u8_list(hex_bytes: bytes):
- """Converts a sequence of hex bytes to a list of unsigned 8-bit integers."""
- hex_chunks = [
- int(hex_bytes[i : i + 2], 16) for i in range(0, len(hex_bytes), 2)
- ]
- return hex_chunks
-
- def _create_seal_hash(block_and_hotkey_hash_hex_: bytes, nonce: int) -> bytes:
- """Creates a seal hash from the block and hotkey hash and nonce."""
- nonce_bytes = binascii.hexlify(nonce.to_bytes(8, "little"))
- pre_seal = nonce_bytes + block_and_hotkey_hash_hex_
- seal_sh256 = hashlib.sha256(bytearray(_hex_bytes_to_u8_list(pre_seal))).digest()
- kec = keccak.new(digest_bits=256)
- return kec.update(seal_sh256).digest()
-
- def _seal_meets_difficulty(seal_: bytes, difficulty_: int):
- """Checks if the seal meets the given difficulty."""
- seal_number = int.from_bytes(seal_, "big")
- product = seal_number * difficulty_
- limit_ = int(math.pow(2, 256)) - 1
-
- return product < limit_
-
# Call cython function
# int blockSize, uint64 nonce_start, uint64 update_interval, const unsigned char[:] limit,
# const unsigned char[:] block_bytes, int dev_id
@@ -101,7 +94,7 @@ def _seal_meets_difficulty(seal_: bytes, difficulty_: int):
seal = None
if solution != -1:
seal = _create_seal_hash(block_and_hotkey_hash_hex, solution)
- if _seal_meets_difficulty(seal, difficulty):
+ if _seal_meets_difficulty(seal, difficulty, limit):
return solution, seal
else:
return -1, b"\x00" * 32
diff --git a/tests/unit_tests/extrinsics/test_init.py b/tests/unit_tests/extrinsics/test__init__.py
similarity index 100%
rename from tests/unit_tests/extrinsics/test_init.py
rename to tests/unit_tests/extrinsics/test__init__.py
diff --git a/tests/unit_tests/extrinsics/test_registration.py b/tests/unit_tests/extrinsics/test_registration.py
index 18d14fac10..18676619de 100644
--- a/tests/unit_tests/extrinsics/test_registration.py
+++ b/tests/unit_tests/extrinsics/test_registration.py
@@ -91,7 +91,7 @@ def test_register_extrinsic_without_pow(
),
mocker.patch("torch.cuda.is_available", return_value=cuda_available),
mocker.patch(
- "bittensor.utils.registration._get_block_with_retry",
+ "bittensor.utils.registration.pow._get_block_with_retry",
return_value=(0, 0, "00ff11ee"),
),
):
@@ -142,10 +142,10 @@ def test_register_extrinsic_with_pow(
):
# Arrange
with mocker.patch(
- "bittensor.utils.registration._solve_for_difficulty_fast",
+ "bittensor.utils.registration.pow._solve_for_difficulty_fast",
return_value=mock_pow_solution if pow_success else None,
), mocker.patch(
- "bittensor.utils.registration._solve_for_difficulty_fast_cuda",
+ "bittensor.utils.registration.pow._solve_for_difficulty_fast_cuda",
return_value=mock_pow_solution if pow_success else None,
), mocker.patch(
"bittensor.core.extrinsics.registration._do_pow_register",
diff --git a/tests/unit_tests/test_async_subtensor.py b/tests/unit_tests/test_async_subtensor.py
index 87813761dc..c90309e808 100644
--- a/tests/unit_tests/test_async_subtensor.py
+++ b/tests/unit_tests/test_async_subtensor.py
@@ -1073,6 +1073,126 @@ async def test_neurons_lite(subtensor, mocker, fake_hex_bytes_result, response):
assert result == []
+@pytest.mark.asyncio
+async def test_get_neuron_for_pubkey_and_subnet_success(subtensor, mocker):
+ """Tests successful retrieval of neuron information."""
+ # Preps
+ fake_hotkey = "fake_ss58_address"
+ fake_netuid = 1
+ fake_uid = 123
+ fake_result = b"fake_neuron_data"
+
+ mocker.patch.object(
+ subtensor.substrate,
+ "query",
+ return_value=fake_uid,
+ )
+ mocker.patch.object(
+ subtensor.substrate,
+ "rpc_request",
+ return_value={"result": fake_result},
+ )
+ mocked_neuron_info = mocker.patch.object(
+ async_subtensor.NeuronInfo, "from_vec_u8", return_value="fake_neuron_info"
+ )
+
+ # Call
+ result = await subtensor.get_neuron_for_pubkey_and_subnet(
+ hotkey_ss58=fake_hotkey, netuid=fake_netuid
+ )
+
+ # Asserts
+ subtensor.substrate.query.assert_awaited_once()
+ subtensor.substrate.query.assert_called_once_with(
+ module="SubtensorModule",
+ storage_function="Uids",
+ params=[fake_netuid, fake_hotkey],
+ block_hash=None,
+ reuse_block_hash=False,
+ )
+ subtensor.substrate.rpc_request.assert_awaited_once()
+ subtensor.substrate.rpc_request.assert_called_once_with(
+ method="neuronInfo_getNeuron", params=[fake_netuid, fake_uid]
+ )
+ mocked_neuron_info.assert_called_once_with(fake_result)
+ assert result == "fake_neuron_info"
+
+
+@pytest.mark.asyncio
+async def test_get_neuron_for_pubkey_and_subnet_uid_not_found(subtensor, mocker):
+ """Tests the case where UID is not found."""
+ # Preps
+ fake_hotkey = "fake_ss58_address"
+ fake_netuid = 1
+
+ mocker.patch.object(
+ subtensor.substrate,
+ "query",
+ return_value=None,
+ )
+ mocked_get_null_neuron = mocker.patch.object(
+ async_subtensor.NeuronInfo, "get_null_neuron", return_value="null_neuron"
+ )
+
+ # Call
+ result = await subtensor.get_neuron_for_pubkey_and_subnet(
+ hotkey_ss58=fake_hotkey, netuid=fake_netuid
+ )
+
+ # Asserts
+ subtensor.substrate.query.assert_called_once_with(
+ module="SubtensorModule",
+ storage_function="Uids",
+ params=[fake_netuid, fake_hotkey],
+ block_hash=None,
+ reuse_block_hash=False,
+ )
+ mocked_get_null_neuron.assert_called_once()
+ assert result == "null_neuron"
+
+
+@pytest.mark.asyncio
+async def test_get_neuron_for_pubkey_and_subnet_rpc_result_empty(subtensor, mocker):
+ """Tests the case where RPC result is empty."""
+ # Preps
+ fake_hotkey = "fake_ss58_address"
+ fake_netuid = 1
+ fake_uid = 123
+
+ mocker.patch.object(
+ subtensor.substrate,
+ "query",
+ return_value=fake_uid,
+ )
+ mocker.patch.object(
+ subtensor.substrate,
+ "rpc_request",
+ return_value={"result": None},
+ )
+ mocked_get_null_neuron = mocker.patch.object(
+ async_subtensor.NeuronInfo, "get_null_neuron", return_value="null_neuron"
+ )
+
+ # Call
+ result = await subtensor.get_neuron_for_pubkey_and_subnet(
+ hotkey_ss58=fake_hotkey, netuid=fake_netuid
+ )
+
+ # Asserts
+ subtensor.substrate.query.assert_called_once_with(
+ module="SubtensorModule",
+ storage_function="Uids",
+ params=[fake_netuid, fake_hotkey],
+ block_hash=None,
+ reuse_block_hash=False,
+ )
+ subtensor.substrate.rpc_request.assert_called_once_with(
+ method="neuronInfo_getNeuron", params=[fake_netuid, fake_uid]
+ )
+ mocked_get_null_neuron.assert_called_once()
+ assert result == "null_neuron"
+
+
@pytest.mark.asyncio
async def test_neuron_for_uid_happy_path(subtensor, mocker):
"""Tests neuron_for_uid method with happy path."""
diff --git a/tests/unit_tests/utils/test_init.py b/tests/unit_tests/utils/test_init.py
deleted file mode 100644
index fbbc8d5bc9..0000000000
--- a/tests/unit_tests/utils/test_init.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import pytest
-
-from bittensor import warnings, __getattr__, version_split, logging, trace, debug
-
-
-def test_getattr_version_split():
- """Test that __getattr__ for 'version_split' issues a deprecation warning and returns the correct value."""
- with warnings.catch_warnings(record=True) as w:
- warnings.simplefilter("always")
- assert __getattr__("version_split") == version_split
- assert len(w) == 1
- assert issubclass(w[-1].category, DeprecationWarning)
- assert "version_split is deprecated" in str(w[-1].message)
-
-
-@pytest.mark.parametrize("test_input, expected", [(True, "Trace"), (False, "Default")])
-def test_trace(test_input, expected):
- """Test the trace function turns tracing on|off."""
- trace(test_input)
- assert logging.current_state_value == expected
-
-
-@pytest.mark.parametrize("test_input, expected", [(True, "Debug"), (False, "Default")])
-def test_debug(test_input, expected):
- """Test the debug function turns tracing on|off."""
- debug(test_input)
- assert logging.current_state_value == expected
diff --git a/tests/unit_tests/utils/test_registration.py b/tests/unit_tests/utils/test_registration.py
index c85608b5f3..ccef37dfb2 100644
--- a/tests/unit_tests/utils/test_registration.py
+++ b/tests/unit_tests/utils/test_registration.py
@@ -31,7 +31,7 @@ def error(self, message):
@pytest.fixture
def mock_bittensor_logging(monkeypatch):
mock_logger = MockBittensorLogging()
- monkeypatch.setattr("bittensor.utils.registration.logging", mock_logger)
+ monkeypatch.setattr("bittensor.utils.registration.pow.logging", mock_logger)
return mock_logger
@@ -48,7 +48,9 @@ def test_lazy_loaded_torch__torch_installed(monkeypatch, mock_bittensor_logging)
def test_lazy_loaded_torch__no_torch(monkeypatch, mock_bittensor_logging):
- monkeypatch.setattr("bittensor.utils.registration._get_real_torch", lambda: None)
+ monkeypatch.setattr(
+ "bittensor.utils.registration.pow._get_real_torch", lambda: None
+ )
torch = LazyLoadedTorch()
diff --git a/tests/unit_tests/utils/test_utils.py b/tests/unit_tests/utils/test_utils.py
index eda2eeb100..469ab3cc3f 100644
--- a/tests/unit_tests/utils/test_utils.py
+++ b/tests/unit_tests/utils/test_utils.py
@@ -15,13 +15,37 @@
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
-from bittensor_wallet import Wallet
import pytest
+from bittensor_wallet import Wallet
-from bittensor import utils
+from bittensor import warnings, __getattr__, version_split, logging, trace, debug, utils
from bittensor.core.settings import SS58_FORMAT
+def test_getattr_version_split():
+ """Test that __getattr__ for 'version_split' issues a deprecation warning and returns the correct value."""
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ assert __getattr__("version_split") == version_split
+ assert len(w) == 1
+ assert issubclass(w[-1].category, DeprecationWarning)
+ assert "version_split is deprecated" in str(w[-1].message)
+
+
+@pytest.mark.parametrize("test_input, expected", [(True, "Trace"), (False, "Default")])
+def test_trace(test_input, expected):
+ """Test the trace function turns tracing on|off."""
+ trace(test_input)
+ assert logging.current_state_value == expected
+
+
+@pytest.mark.parametrize("test_input, expected", [(True, "Debug"), (False, "Default")])
+def test_debug(test_input, expected):
+    """Test the debug function turns debug logging on|off."""
+ debug(test_input)
+ assert logging.current_state_value == expected
+
+
def test_ss58_to_vec_u8(mocker):
"""Tests `utils.ss58_to_vec_u8` function."""
# Prep