diff --git a/bittensor/_subtensor/__init__.py b/bittensor/_subtensor/__init__.py index eab11fd1cc..cd60b673ac 100644 --- a/bittensor/_subtensor/__init__.py +++ b/bittensor/_subtensor/__init__.py @@ -192,7 +192,7 @@ def add_args(cls, parser: argparse.ArgumentParser, prefix: str = None ): # registration args. Used for register and re-register and anything that calls register. parser.add_argument('--' + prefix_str + 'subtensor.register.num_processes', '-n', dest=prefix_str + 'subtensor.register.num_processes', help="Number of processors to use for registration", type=int, default=bittensor.defaults.subtensor.register.num_processes) parser.add_argument('--' + prefix_str + 'subtensor.register.update_interval', '--' + prefix_str + 'subtensor.register.cuda.update_interval', '--' + prefix_str + 'cuda.update_interval', '-u', help="The number of nonces to process before checking for next block during registration", type=int, default=bittensor.defaults.subtensor.register.update_interval) - parser.add_argument('--' + prefix_str + 'subtensor.register.output_in_place', help="Whether to ouput the registration statistics in-place. Set flag to enable.", action='store_true', required=False, default=bittensor.defaults.subtensor.register.output_in_place) + parser.add_argument('--' + prefix_str + 'subtensor.register.no_output_in_place', '--' + prefix_str + 'no_output_in_place', dest="subtensor.register.output_in_place", help="Whether to not ouput the registration statistics in-place. Set flag to disable output in-place.", action='store_false', required=False, default=bittensor.defaults.subtensor.register.output_in_place) parser.add_argument('--' + prefix_str + 'subtensor.register.verbose', help="Whether to ouput the registration statistics verbosely.", action='store_true', required=False, default=bittensor.defaults.subtensor.register.verbose) ## Registration args for CUDA registration. diff --git a/bittensor/utils/__init__.py b/bittensor/utils/__init__.py index ff0d9af119..9526ad41fc 100644 --- a/bittensor/utils/__init__.py +++ b/bittensor/utils/__init__.py @@ -8,7 +8,7 @@ import random import time from dataclasses import dataclass -from queue import Empty +from queue import Empty, Full from typing import Any, Dict, List, Optional, Tuple, Union, Callable import backoff @@ -158,8 +158,8 @@ class SolverBase(multiprocessing.Process): The total number of processes running. update_interval: int The number of nonces to try to solve before checking for a new block. - time_queue: multiprocessing.Queue - The queue to put the time the process took to finish each update_interval. + finished_queue: multiprocessing.Queue + The queue to put the process number when a process finishes each update_interval. Used for calculating the average time per update_interval across all processes. solution_queue: multiprocessing.Queue The queue to put the solution the process has found during the pow solve. @@ -192,7 +192,7 @@ class SolverBase(multiprocessing.Process): proc_num: int num_proc: int update_interval: int - time_queue: multiprocessing.Queue + finished_queue: multiprocessing.Queue solution_queue: multiprocessing.Queue newBlockEvent: multiprocessing.Event stopEvent: multiprocessing.Event @@ -202,12 +202,12 @@ class SolverBase(multiprocessing.Process): check_block: multiprocessing.Lock limit: int - def __init__(self, proc_num, num_proc, update_interval, time_queue, solution_queue, stopEvent, curr_block, curr_block_num, curr_diff, check_block, limit): + def __init__(self, proc_num, num_proc, update_interval, finished_queue, solution_queue, stopEvent, curr_block, curr_block_num, curr_diff, check_block, limit): multiprocessing.Process.__init__(self) self.proc_num = proc_num self.num_proc = num_proc self.update_interval = update_interval - self.time_queue = time_queue + self.finished_queue = finished_queue self.solution_queue = solution_queue self.newBlockEvent = multiprocessing.Event() self.newBlockEvent.clear() @@ -239,41 +239,39 @@ def run(self): block_difficulty = registration_diff_unpack(self.curr_diff) self.newBlockEvent.clear() - # reset nonces to start from random point - # prevents the same nonces (for each block) from being tried by multiple processes - # also prevents the same nonces from being tried by multiple peers - nonce_start = random.randint( 0, nonce_limit ) - nonce_end = nonce_start + self.update_interval # Do a block of nonces - solution, time = solve_for_nonce_block(self, nonce_start, nonce_end, block_bytes, block_difficulty, self.limit, block_number) + solution = solve_for_nonce_block(self, nonce_start, nonce_end, block_bytes, block_difficulty, self.limit, block_number) if solution is not None: self.solution_queue.put(solution) - # Send time - self.time_queue.put_nowait(time) + try: + # Send time + self.finished_queue.put_nowait(self.proc_num) + except Full: + pass - nonce_start += self.update_interval * self.num_proc - nonce_end += self.update_interval * self.num_proc + nonce_start = random.randint( 0, nonce_limit ) + nonce_start = nonce_start % nonce_limit + nonce_end = nonce_start + self.update_interval class CUDASolver(SolverBase): dev_id: int TPB: int - def __init__(self, proc_num, num_proc, update_interval, time_queue, solution_queue, stopEvent, curr_block, curr_block_num, curr_diff, check_block, limit, dev_id: int, TPB: int): - super().__init__(proc_num, num_proc, update_interval, time_queue, solution_queue, stopEvent, curr_block, curr_block_num, curr_diff, check_block, limit) + def __init__(self, proc_num, num_proc, update_interval, finished_queue, solution_queue, stopEvent, curr_block, curr_block_num, curr_diff, check_block, limit, dev_id: int, TPB: int): + super().__init__(proc_num, num_proc, update_interval, finished_queue, solution_queue, stopEvent, curr_block, curr_block_num, curr_diff, check_block, limit) self.dev_id = dev_id self.TPB = TPB def run(self): - block_number: int - block_bytes: bytes - block_difficulty: int - nonce_limit = int(math.pow(2,64)) - 1 + block_number: int = 0 # dummy value + block_bytes: bytes = b'0' * 32 # dummy value + block_difficulty: int = int(math.pow(2,64)) - 1 # dummy value + nonce_limit = int(math.pow(2,64)) - 1 # U64MAX # Start at random nonce - nonce_start = self.TPB * self.update_interval * self.proc_num + random.randint( 0, nonce_limit ) - nonce_end = nonce_start + self.update_interval * self.TPB + nonce_start = random.randint( 0, nonce_limit ) while not self.stopEvent.is_set(): if self.newBlockEvent.is_set(): with self.check_block: @@ -282,26 +280,26 @@ def run(self): block_difficulty = registration_diff_unpack(self.curr_diff) self.newBlockEvent.clear() - # reset nonces to start from random point - nonce_start = self.update_interval * self.proc_num + random.randint( 0, nonce_limit ) - nonce_end = nonce_start + self.update_interval # Do a block of nonces - solution, time = solve_for_nonce_block_cuda(self, nonce_start, self.update_interval, block_bytes, block_difficulty, self.limit, block_number, self.dev_id, self.TPB) + solution = solve_for_nonce_block_cuda(self, nonce_start, self.update_interval, block_bytes, block_difficulty, self.limit, block_number, self.dev_id, self.TPB) if solution is not None: self.solution_queue.put(solution) - # Send time - self.time_queue.put_nowait(time) - - nonce_start += self.update_interval * self.num_proc + try: + # Signal that a nonce_block was finished using queue + # send our proc_num + self.finished_queue.put(self.proc_num) + except Full: + pass + + # increase nonce by number of nonces processed + nonce_start += self.update_interval * self.TPB nonce_start = nonce_start % nonce_limit - nonce_end += self.update_interval * self.num_proc - -def solve_for_nonce_block_cuda(solver: CUDASolver, nonce_start: int, update_interval: int, block_bytes: bytes, difficulty: int, limit: int, block_number: int, dev_id: int, TPB: int) -> Tuple[Optional[POWSolution], int]: - start = time.time() +def solve_for_nonce_block_cuda(solver: CUDASolver, nonce_start: int, update_interval: int, block_bytes: bytes, difficulty: int, limit: int, block_number: int, dev_id: int, TPB: int) -> Optional[POWSolution]: + """Tries to solve the POW on a CUDA device for a block of nonces (nonce_start, nonce_start + update_interval * TPB""" solution, seal = solve_cuda(nonce_start, update_interval, TPB, @@ -312,19 +310,14 @@ def solve_for_nonce_block_cuda(solver: CUDASolver, nonce_start: int, update_inte dev_id) if (solution != -1): - # Check if solution is valid - # Attempt to reset CUDA device - #reset_cuda() - - #print(f"{solver.proc_num} on cuda:{solver.dev_id} found a solution: {solution}, {block_number}, {str(block_bytes)}, {str(seal)}, {difficulty}") - # Found a solution, save it. - return POWSolution(solution, block_number, difficulty, seal), time.time() - start + # Check if solution is valid (i.e. not -1) + return POWSolution(solution, block_number, difficulty, seal) - return None, time.time() - start + return None -def solve_for_nonce_block(solver: Solver, nonce_start: int, nonce_end: int, block_bytes: bytes, difficulty: int, limit: int, block_number: int) -> Tuple[Optional[POWSolution], int]: - start = time.time() +def solve_for_nonce_block(solver: Solver, nonce_start: int, nonce_end: int, block_bytes: bytes, difficulty: int, limit: int, block_number: int) -> Optional[POWSolution]: + """Tries to solve the POW for a block of nonces (nonce_start, nonce_end)""" for nonce in range(nonce_start, nonce_end): # Create seal. nonce_bytes = binascii.hexlify(nonce.to_bytes(8, 'little')) @@ -338,9 +331,9 @@ def solve_for_nonce_block(solver: Solver, nonce_start: int, nonce_end: int, bloc product = seal_number * difficulty if product < limit: # Found a solution, save it. - return POWSolution(nonce, block_number, difficulty, seal), time.time() - start + return POWSolution(nonce, block_number, difficulty, seal) - return None, time.time() - start + return None def registration_diff_unpack(packed_diff: multiprocessing.Array) -> int: @@ -353,6 +346,9 @@ def registration_diff_pack(diff: int, packed_diff: multiprocessing.Array): packed_diff[0] = diff >> 32 packed_diff[1] = diff & 0xFFFFFFFF # low 32 bits +def calculate_hash_rate() -> int: + pass + def update_curr_block(curr_diff: multiprocessing.Array, curr_block: multiprocessing.Array, curr_block_num: multiprocessing.Value, block_number: int, block_bytes: bytes, diff: int, lock: multiprocessing.Lock): with lock: @@ -373,7 +369,6 @@ def get_cpu_count(): class RegistrationStatistics: """Statistics for a registration.""" time_spent_total: float - time_average_perpetual: float rounds_total: int time_average: float time_spent: float @@ -411,8 +406,8 @@ def get_status_message(cls, stats: RegistrationStatistics, verbose: bool = False time spent: {timedelta(seconds=stats.time_spent)}""" + \ (f""" time spent total: {stats.time_spent_total:.2f} s - time average perpetual: {timedelta(seconds=stats.time_average_perpetual)} - """ if verbose else "") + f""" + time spent average: {timedelta(seconds=stats.time_average)}""" if verbose else "") + \ + f""" Difficulty: [bold white]{millify(stats.difficulty)}[/bold white] Iters: [bold white]{get_human_readable(int(stats.hash_rate), 'H')}/s[/bold white] Block: [bold white]{stats.block_number}[/bold white] @@ -427,7 +422,7 @@ def update( self, stats: RegistrationStatistics, verbose: bool = False ) -> None self.console.log( self.get_status_message(stats, verbose=verbose), ) -def solve_for_difficulty_fast( subtensor, wallet, output_in_place: bool = True, num_processes: Optional[int] = None, update_interval: Optional[int] = None, log_verbose: bool = False ) -> Optional[POWSolution]: +def solve_for_difficulty_fast( subtensor, wallet, output_in_place: bool = True, num_processes: Optional[int] = None, update_interval: Optional[int] = None, n_samples: int = 5, alpha_: float = 0.70, log_verbose: bool = False ) -> Optional[POWSolution]: """ Solves the POW for registration using multiprocessing. Args: @@ -441,8 +436,13 @@ def solve_for_difficulty_fast( subtensor, wallet, output_in_place: bool = True, Number of processes to use. update_interval: int Number of nonces to solve before updating block information. + n_samples: int + The number of samples of the hash_rate to keep for the EWMA + alpha_: float + The alpha for the EWMA for the hash_rate calculation log_verbose: bool If true, prints more verbose logging of the registration metrics. + Note: The hash rate is calculated as an exponentially weighted moving average in order to make the measure more robust. Note: - We can also modify the update interval to do smaller blocks of work, while still updating the block information after a different number of nonces, @@ -467,11 +467,11 @@ def solve_for_difficulty_fast( subtensor, wallet, output_in_place: bool = True, stopEvent.clear() solution_queue = multiprocessing.Queue() - time_queue = multiprocessing.Queue() + finished_queue = multiprocessing.Queue() check_block = multiprocessing.Lock() # Start consumers - solvers = [ Solver(i, num_processes, update_interval, time_queue, solution_queue, stopEvent, curr_block, curr_block_num, curr_diff, check_block, limit) + solvers = [ Solver(i, num_processes, update_interval, finished_queue, solution_queue, stopEvent, curr_block, curr_block_num, curr_diff, check_block, limit) for i in range(num_processes) ] # Get first block @@ -485,16 +485,18 @@ def solve_for_difficulty_fast( subtensor, wallet, output_in_place: bool = True, # Set to current block update_curr_block(curr_diff, curr_block, curr_block_num, block_number, block_bytes, difficulty, check_block) - # Set new block events for each solver to start - for w in solvers: - w.newBlockEvent.set() + # Set new block events for each solver to start at the initial block + for worker in solvers: + worker.newBlockEvent.set() - for w in solvers: - w.start() # start the solver processes + for worker in solvers: + worker.start() # start the solver processes + + start_time = time.time() # time that the registration started + time_last = start_time # time that the last work blocks completed curr_stats = RegistrationStatistics( time_spent_total = 0.0, - time_average_perpetual = 0.0, time_average = 0.0, rounds_total = 0, time_spent = 0.0, @@ -506,16 +508,19 @@ def solve_for_difficulty_fast( subtensor, wallet, output_in_place: bool = True, ) start_time_perpetual = time.time() + console = bittensor.__console__ logger = RegistrationStatisticsLogger(console, output_in_place) logger.start() solution = None + hash_rate = 0 # EWMA hash_rate (H/s) + hash_rates = [0] * n_samples # The last n true hash_rates + weights = [alpha_ ** i for i in range(n_samples)] # weights decay by alpha + while not wallet.is_registered(subtensor): - start_time = time.time() - time_avg: Optional[float] = None # Wait until a solver finds a solution try: solution = solution_queue.get(block=True, timeout=0.25) @@ -538,34 +543,41 @@ def solve_for_difficulty_fast( subtensor, wallet, output_in_place: bool = True, update_curr_block(curr_diff, curr_block, curr_block_num, block_number, block_bytes, difficulty, check_block) # Set new block events for each solver - for w in solvers: - w.newBlockEvent.set() + for worker in solvers: + worker.newBlockEvent.set() # update stats curr_stats.block_number = block_number curr_stats.block_hash = block_hash curr_stats.difficulty = difficulty - # Get times for each solver - time_total = 0 num_time = 0 - - for _ in solvers: + for _ in range(len(solvers)*2): try: - time_total += time_queue.get_nowait() + proc_num = finished_queue.get(timeout=0.1) num_time += 1 + except Empty: - break + # no more times + continue - # Calculate average time per solver for the update_interval - if num_time > 0: - time_avg = time_total / num_time - curr_stats.hash_rate = update_interval*num_processes / time_avg - - curr_stats.time_spent = time.time() - start_time - new_time_spent_total = time.time() - start_time_perpetual - curr_stats.time_average = time_avg if not None else curr_stats.time_average - curr_stats.time_average_perpetual = (curr_stats.time_average_perpetual*curr_stats.rounds_total + curr_stats.time_spent)/(curr_stats.rounds_total+1) + time_now = time.time() # get current time + time_since_last = time_now - time_last # get time since last work block(s) + if num_time > 0 and time_since_last > 0.0: + # create EWMA of the hash_rate to make measure more robust + + hash_rate_ = (num_time * update_interval) / time_since_last + hash_rates.append(hash_rate_) + hash_rates.pop(0) # remove the 0th data point + curr_stats.hash_rate = sum([hash_rates[i]*weights[i] for i in range(n_samples)])/(sum(weights)) + + # update time last to now + time_last = time_now + + # Update stats + curr_stats.time_spent = time_since_last + new_time_spent_total = time_now - start_time_perpetual + curr_stats.time_average = (curr_stats.time_average*curr_stats.rounds_total + curr_stats.time_spent)/(curr_stats.rounds_total+1) curr_stats.rounds_total += 1 curr_stats.hash_rate_perpetual = (curr_stats.time_spent_total*curr_stats.hash_rate_perpetual + curr_stats.hash_rate)/ new_time_spent_total curr_stats.time_spent_total = new_time_spent_total @@ -577,6 +589,9 @@ def solve_for_difficulty_fast( subtensor, wallet, output_in_place: bool = True, stopEvent.set() # stop all other processes logger.stop() + # terminate and wait for all solvers to exit + terminate_workers_and_wait_for_exit(solvers) + return solution def get_human_readable(num, suffix="H"): @@ -623,7 +638,7 @@ def __exit__(self, *args): multiprocessing.set_start_method(self._old_start_method, force=True) -def solve_for_difficulty_fast_cuda( subtensor: 'bittensor.Subtensor', wallet: 'bittensor.Wallet', output_in_place: bool = True, update_interval: int = 50_000, TPB: int = 512, dev_id: Union[List[int], int] = 0, log_verbose: bool = False ) -> Optional[POWSolution]: +def solve_for_difficulty_fast_cuda( subtensor: 'bittensor.Subtensor', wallet: 'bittensor.Wallet', output_in_place: bool = True, update_interval: int = 50_000, TPB: int = 512, dev_id: Union[List[int], int] = 0, n_samples: int = 5, alpha_: float = 0.70, log_verbose: bool = False ) -> Optional[POWSolution]: """ Solves the registration fast using CUDA Args: @@ -639,8 +654,13 @@ def solve_for_difficulty_fast_cuda( subtensor: 'bittensor.Subtensor', wallet: 'b The number of threads per block. CUDA param that should match the GPU capability dev_id: Union[List[int], int] The CUDA device IDs to execute the registration on, either a single device or a list of devices + n_samples: int + The number of samples of the hash_rate to keep for the EWMA + alpha_: float + The alpha for the EWMA for the hash_rate calculation log_verbose: bool If true, prints more verbose logging of the registration metrics. + Note: The hash rate is calculated as an exponentially weighted moving average in order to make the measure more robust. """ if isinstance(dev_id, int): dev_id = [dev_id] @@ -672,15 +692,17 @@ def update_curr_block(block_number: int, block_bytes: bytes, diff: int, lock: mu stopEvent = multiprocessing.Event() stopEvent.clear() solution_queue = multiprocessing.Queue() - time_queue = multiprocessing.Queue() + finished_queue = multiprocessing.Queue() check_block = multiprocessing.Lock() - # Start consumers + # Start workers + ## Create a worker per CUDA device num_processes = len(dev_id) - ## Create one consumer per GPU - solvers = [ CUDASolver(i, num_processes, update_interval, time_queue, solution_queue, stopEvent, curr_block, curr_block_num, curr_diff, check_block, limit, dev_id[i], TPB) + + solvers = [ CUDASolver(i, num_processes, update_interval, finished_queue, solution_queue, stopEvent, curr_block, curr_block_num, curr_diff, check_block, limit, dev_id[i], TPB) for i in range(num_processes) ] + # Get first block block_number = subtensor.get_current_block() difficulty = subtensor.difficulty @@ -693,21 +715,23 @@ def update_curr_block(block_number: int, block_bytes: bytes, diff: int, lock: mu # Set to current block update_curr_block(block_number, block_bytes, difficulty, check_block) - # Set new block events for each solver to start - for w in solvers: - w.newBlockEvent.set() + # Set new block events for each solver to start at the initial block + for worker in solvers: + worker.newBlockEvent.set() + + for worker in solvers: + worker.start() # start the solver processes - for w in solvers: - w.start() # start the solver processes + start_time = time.time() # time that the registration started + time_last = start_time # time that the last work blocks completed curr_stats = RegistrationStatistics( time_spent_total = 0.0, - time_average_perpetual = 0.0, time_average = 0.0, rounds_total = 0, time_spent = 0.0, hash_rate_perpetual = 0.0, - hash_rate = 0.0, + hash_rate = 0.0, # EWMA hash_rate (H/s) difficulty = difficulty, block_number = block_number, block_hash = block_hash @@ -719,11 +743,10 @@ def update_curr_block(block_number: int, block_bytes: bytes, diff: int, lock: mu logger = RegistrationStatisticsLogger(console, output_in_place) logger.start() - solution = None + hash_rates = [0] * n_samples # The last n true hash_rates + weights = [alpha_ ** i for i in range(n_samples)] # weights decay by alpha while not wallet.is_registered(subtensor): - start_time = time.time() - time_avg: Optional[float] = None # Wait until a solver finds a solution try: solution = solution_queue.get(block=True, timeout=0.15) @@ -744,34 +767,44 @@ def update_curr_block(block_number: int, block_bytes: bytes, diff: int, lock: mu update_curr_block(block_number, block_bytes, difficulty, check_block) # Set new block events for each solver - for w in solvers: - w.newBlockEvent.set() + + for worker in solvers: + worker.newBlockEvent.set() + # update stats curr_stats.block_number = block_number curr_stats.block_hash = block_hash curr_stats.difficulty = difficulty - # Get times for each solver - time_total = 0 num_time = 0 - for _ in solvers: + # Get times for each solver + for _ in range(len(solvers)*2): try: - time_ = time_queue.get(timeout=0.01) - time_total += time_ + proc_num = finished_queue.get(timeout=0.1) num_time += 1 - + except Empty: - break + # no more times + continue - if num_time > 0: - time_avg = time_total / num_time - curr_stats.hash_rate = TPB*update_interval*num_processes / time_avg - - curr_stats.time_spent = time.time() - start_time - new_time_spent_total = time.time() - start_time_perpetual - curr_stats.time_average = time_avg if not None else curr_stats.time_average - curr_stats.time_average_perpetual = (curr_stats.time_average_perpetual*curr_stats.rounds_total + curr_stats.time_spent)/(curr_stats.rounds_total+1) + time_now = time.time() # get current time + time_since_last = time_now - time_last # get time since last work block(s) + if num_time > 0 and time_since_last > 0.0: + # create EWMA of the hash_rate to make measure more robust + + hash_rate_ = (num_time * TPB * update_interval) / time_since_last + hash_rates.append(hash_rate_) + hash_rates.pop(0) # remove the 0th data point + curr_stats.hash_rate = sum([hash_rates[i]*weights[i] for i in range(n_samples)])/(sum(weights)) + + # update time last to now + time_last = time_now + + # Update stats + curr_stats.time_spent = time_since_last + new_time_spent_total = time_now - start_time_perpetual + curr_stats.time_average = (curr_stats.time_average*curr_stats.rounds_total + curr_stats.time_spent)/(curr_stats.rounds_total+1) curr_stats.rounds_total += 1 curr_stats.hash_rate_perpetual = (curr_stats.time_spent_total*curr_stats.hash_rate_perpetual + curr_stats.hash_rate)/ new_time_spent_total curr_stats.time_spent_total = new_time_spent_total @@ -780,14 +813,20 @@ def update_curr_block(block_number: int, block_bytes: bytes, diff: int, lock: mu logger.update(curr_stats, verbose=log_verbose) # exited while, found_solution contains the nonce or wallet is registered - if solution is not None: - stopEvent.set() # stop all other processes - logger.stop() + + stopEvent.set() # stop all other processes + logger.stop() - return solution + # terminate and wait for all solvers to exit + terminate_workers_and_wait_for_exit(solvers) + + return solution + +def terminate_workers_and_wait_for_exit(workers: List[multiprocessing.Process]) -> None: + for worker in workers: + worker.terminate() + worker.join() - logger.stop() - return None def create_pow( subtensor, diff --git a/tests/unit_tests/bittensor_tests/utils/test_utils.py b/tests/unit_tests/bittensor_tests/utils/test_utils.py index feb1807250..030cdbfb83 100644 --- a/tests/unit_tests/bittensor_tests/utils/test_utils.py +++ b/tests/unit_tests/bittensor_tests/utils/test_utils.py @@ -1,26 +1,24 @@ import binascii import hashlib -import unittest -import bittensor -import sys +import math +import multiprocessing +import os +import random import subprocess +import sys import time -import pytest -import os -import random -import torch -import multiprocessing +import unittest +from sys import platform from types import SimpleNamespace +from unittest.mock import MagicMock, patch -from sys import platform -from substrateinterface.base import Keypair +import bittensor +import pytest +import torch from _pytest.fixtures import fixture +from bittensor.utils import CUDASolver from loguru import logger - -from types import SimpleNamespace - -from unittest.mock import MagicMock, patch - +from substrateinterface.base import Keypair @fixture(scope="function") @@ -400,60 +398,55 @@ class MockException(Exception): assert call1[1]['call_function'] == 'register' call_params = call1[1]['call_params'] assert call_params['nonce'] == mock_result['nonce'] - - -def test_pow_called_for_cuda(): - class MockException(Exception): - pass - mock_compose_call = MagicMock(side_effect=MockException) - - mock_subtensor = bittensor.subtensor(_mock=True) - mock_subtensor.neuron_for_pubkey=MagicMock(is_null=True) - mock_subtensor.substrate = MagicMock( - __enter__= MagicMock(return_value=MagicMock( - compose_call=mock_compose_call - )), - __exit__ = MagicMock(return_value=None), - ) - - mock_wallet = SimpleNamespace( - hotkey=SimpleNamespace( - ss58_address='' - ), - coldkeypub=SimpleNamespace( - ss58_address='' - ) - ) - mock_result = { - "block_number": 1, - 'nonce': random.randint(0, pow(2, 32)), - 'work': b'\x00' * 64, - } +class TestCUDASolverRun(unittest.TestCase): + def test_multi_cuda_run_updates_nonce_start(self): + class MockException(Exception): + pass + + TPB: int = 512 + update_interval: int = 70_000 + nonce_limit: int = int(math.pow(2, 64)) - 1 + + mock_solver_self = MagicMock( + spec=CUDASolver, + TPB=TPB, + dev_id=0, + update_interval=update_interval, + stopEvent=MagicMock(is_set=MagicMock(return_value=False)), + newBlockEvent=MagicMock(is_set=MagicMock(return_value=False)), + finished_queue=MagicMock(put=MagicMock()), + limit=10000, + proc_num=0, + ) - with patch('bittensor.utils.POWNotStale', return_value=True) as mock_pow_not_stale: - with patch('torch.cuda.is_available', return_value=True) as mock_cuda_available: - with patch('bittensor.utils.create_pow', return_value=mock_result) as mock_create_pow: - with patch('bittensor.utils.hex_bytes_to_u8_list', return_value=b''): - - # Should exit early - with pytest.raises(MockException): - mock_subtensor.register(mock_wallet, cuda=True, prompt=False) - - mock_pow_not_stale.assert_called_once() - mock_create_pow.assert_called_once() - mock_cuda_available.assert_called_once() - - call0 = mock_pow_not_stale.call_args - assert call0[0][0] == mock_subtensor - assert call0[0][1] == mock_result - - mock_compose_call.assert_called_once() - call1 = mock_compose_call.call_args - assert call1[1]['call_function'] == 'register' - call_params = call1[1]['call_params'] - assert call_params['nonce'] == mock_result['nonce'] + + with patch('bittensor.utils.solve_for_nonce_block_cuda', + side_effect=[None, MockException] # first call returns mocked no solution, second call raises exception + ) as mock_solve_for_nonce_block_cuda: + + # Should exit early + with pytest.raises(MockException): + CUDASolver.run(mock_solver_self) + + mock_solve_for_nonce_block_cuda.assert_called() + calls = mock_solve_for_nonce_block_cuda.call_args_list + self.assertEqual(len(calls), 2, f"solve_for_nonce_block_cuda was called {len(calls)}. Expected 2") # called only twice + + # args, kwargs + args_call_0, _ = calls[0] + initial_nonce_start: int = args_call_0[1] # second arg should be nonce_start + self.assertIsInstance(initial_nonce_start, int) + + args_call_1, _ = calls[1] + nonce_start_after_iteration: int = args_call_1[1] # second arg should be nonce_start + self.assertIsInstance(nonce_start_after_iteration, int) + + # verify nonce_start is updated after each iteration + self.assertNotEqual(nonce_start_after_iteration, initial_nonce_start, "nonce_start was not updated after iteration") + ## Should incerase by the number of nonces tried == TPB * update_interval + self.assertEqual(nonce_start_after_iteration, (initial_nonce_start + update_interval * TPB) % nonce_limit, "nonce_start was not updated by the correct amount") if __name__ == "__main__": - test_solve_for_difficulty_fast_registered_already() + unittest.main()