diff --git a/bittensor_cli/src/bittensor/extrinsics/root.py b/bittensor_cli/src/bittensor/extrinsics/root.py index f95d9990e..92b2b5938 100644 --- a/bittensor_cli/src/bittensor/extrinsics/root.py +++ b/bittensor_cli/src/bittensor/extrinsics/root.py @@ -48,6 +48,34 @@ U16_MAX = 65535 +async def get_current_weights_for_uid( + subtensor: SubtensorInterface, + netuid: int, + uid: int, +) -> dict[int, float]: + """ + Fetches the current weights set by a specific UID on a subnet. + + Args: + subtensor: The SubtensorInterface instance. + netuid: The network UID (0 for root network). + uid: The UID of the neuron whose weights to fetch. + + Returns: + A dictionary mapping destination netuid to normalized weight (0.0-1.0). + """ + weights_data = await subtensor.weights(netuid=netuid) + current_weights: dict[int, float] = {} + + for validator_uid, weight_list in weights_data: + if validator_uid == uid: + for dest_netuid, raw_weight in weight_list: + current_weights[dest_netuid] = u16_normalized_float(raw_weight) + break + + return current_weights + + async def get_limits(subtensor: SubtensorInterface) -> tuple[int, float]: # Get weight restrictions. maw, mwl = await asyncio.gather( @@ -459,17 +487,39 @@ async def _do_set_weights(): # Ask before moving on. if prompt: + # Fetch current weights for comparison + print_verbose("Fetching current weights for comparison") + current_weights = await get_current_weights_for_uid( + subtensor, netuid=0, uid=my_uid + ) + table = Table( Column("[dark_orange]Netuid", justify="center", style="bold green"), - Column( - "[dark_orange]Weight", justify="center", style="bold light_goldenrod2" - ), + Column("[dark_orange]Current", justify="center", style="dim"), + Column("[dark_orange]New", justify="center", style="bold light_goldenrod2"), + Column("[dark_orange]Change", justify="center"), expand=False, show_edge=False, ) - for netuid, weight in zip(netuids, formatted_weights): - table.add_row(str(netuid), f"{weight:.8f}") + for netuid, new_weight in zip(netuids, formatted_weights): + current_weight = current_weights.get(netuid, 0.0) + diff = new_weight - current_weight + + # Format the difference with color and sign + if diff > 0.00000001: + diff_str = f"[green]+{diff:.8f}[/green]" + elif diff < -0.00000001: + diff_str = f"[red]{diff:.8f}[/red]" + else: + diff_str = "[dim]0.00000000[/dim]" + + table.add_row( + str(netuid), + f"{current_weight:.8f}", + f"{new_weight:.8f}", + diff_str, + ) console.print(table) if not Confirm.ask("\nDo you want to set these root weights?"): diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index a061910e5..0f5218de5 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -1,7 +1,12 @@ +import numpy as np import pytest import typer from bittensor_cli.cli import parse_mnemonic, CLIManager +from bittensor_cli.src.bittensor.extrinsics.root import ( + get_current_weights_for_uid, + set_root_weights_extrinsic, +) from unittest.mock import AsyncMock, patch, MagicMock, Mock @@ -659,3 +664,131 @@ def test_stake_transfer_calls_proxy_validation(): # Assert that proxy validation was called mock_proxy_validation.assert_called_once_with(valid_proxy, False) + + +# ============================================================================ +# Tests for root weights difference display +# ============================================================================ + + +@pytest.mark.asyncio +async def test_get_current_weights_for_uid_success(): + """Test fetching current weights for a specific UID.""" + mock_subtensor = MagicMock() + + # Mock weights data: [(uid, [(dest_netuid, raw_weight), ...]), ...] + mock_weights_data = [ + (0, [(0, 32768), (1, 16384), (2, 16384)]), + (1, [(0, 65535), (1, 0), (2, 0)]), + ] + mock_subtensor.weights = AsyncMock(return_value=mock_weights_data) + + result = await get_current_weights_for_uid(mock_subtensor, netuid=0, uid=0) + + mock_subtensor.weights.assert_called_once_with(netuid=0) + assert 0 in result + assert 1 in result + assert 2 in result + # 32768 / 65535 ≈ 0.5 + assert abs(result[0] - 0.5) < 0.01 + + +@pytest.mark.asyncio +async def test_get_current_weights_for_uid_not_found(): + """Test fetching weights for a UID that doesn't exist.""" + mock_subtensor = MagicMock() + mock_weights_data = [ + (0, [(0, 32768), (1, 16384)]), + (1, [(0, 65535)]), + ] + mock_subtensor.weights = AsyncMock(return_value=mock_weights_data) + + result = await get_current_weights_for_uid(mock_subtensor, netuid=0, uid=5) + + assert result == {} + + +@pytest.mark.asyncio +async def test_get_current_weights_for_uid_empty(): + """Test fetching weights when the network has no weights set.""" + mock_subtensor = MagicMock() + mock_subtensor.weights = AsyncMock(return_value=[]) + + result = await get_current_weights_for_uid(mock_subtensor, netuid=0, uid=0) + + assert result == {} + + +@pytest.mark.asyncio +async def test_set_root_weights_fetches_current_weights_with_prompt(): + """Test that set_root_weights fetches current weights when prompt=True.""" + mock_subtensor = MagicMock() + mock_wallet = MagicMock() + mock_subtensor.query = AsyncMock(return_value=0) + + with ( + patch("bittensor_cli.src.bittensor.extrinsics.root.unlock_key") as mock_unlock, + patch("bittensor_cli.src.bittensor.extrinsics.root.get_limits") as mock_limits, + patch( + "bittensor_cli.src.bittensor.extrinsics.root.get_current_weights_for_uid" + ) as mock_get_current, + patch("bittensor_cli.src.bittensor.extrinsics.root.console"), + patch("bittensor_cli.src.bittensor.extrinsics.root.Confirm") as mock_confirm, + ): + mock_unlock.return_value = MagicMock(success=True) + mock_limits.return_value = (1, 0.5) + mock_get_current.return_value = {0: 0.5, 1: 0.3, 2: 0.2} + mock_confirm.ask.return_value = False + + netuids = np.array([0, 1, 2], dtype=np.int64) + weights = np.array([0.4, 0.3, 0.3], dtype=np.float32) + + await set_root_weights_extrinsic( + subtensor=mock_subtensor, + wallet=mock_wallet, + netuids=netuids, + weights=weights, + prompt=True, + ) + + mock_get_current.assert_called_once_with(mock_subtensor, netuid=0, uid=0) + + +@pytest.mark.asyncio +async def test_set_root_weights_skips_current_weights_without_prompt(): + """Test that set_root_weights skips fetching current weights when prompt=False.""" + mock_subtensor = MagicMock() + mock_wallet = MagicMock() + mock_subtensor.query = AsyncMock(return_value=0) + mock_subtensor.substrate = MagicMock() + mock_subtensor.substrate.compose_call = AsyncMock() + mock_subtensor.substrate.create_signed_extrinsic = AsyncMock() + mock_response = MagicMock() + mock_response.is_success = True + mock_subtensor.substrate.submit_extrinsic = AsyncMock(return_value=mock_response) + + with ( + patch("bittensor_cli.src.bittensor.extrinsics.root.unlock_key") as mock_unlock, + patch("bittensor_cli.src.bittensor.extrinsics.root.get_limits") as mock_limits, + patch( + "bittensor_cli.src.bittensor.extrinsics.root.get_current_weights_for_uid" + ) as mock_get_current, + patch("bittensor_cli.src.bittensor.extrinsics.root.console"), + ): + mock_unlock.return_value = MagicMock(success=True) + mock_limits.return_value = (1, 0.5) + + netuids = np.array([0, 1, 2], dtype=np.int64) + weights = np.array([0.4, 0.3, 0.3], dtype=np.float32) + + await set_root_weights_extrinsic( + subtensor=mock_subtensor, + wallet=mock_wallet, + netuids=netuids, + weights=weights, + prompt=False, + wait_for_inclusion=False, + wait_for_finalization=False, + ) + + mock_get_current.assert_not_called()