Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 55 additions & 5 deletions bittensor_cli/src/bittensor/extrinsics/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,34 @@
U16_MAX = 65535


async def get_current_weights_for_uid(
subtensor: SubtensorInterface,
netuid: int,
uid: int,
) -> dict[int, float]:
"""
Fetches the current weights set by a specific UID on a subnet.

Args:
subtensor: The SubtensorInterface instance.
netuid: The network UID (0 for root network).
uid: The UID of the neuron whose weights to fetch.

Returns:
A dictionary mapping destination netuid to normalized weight (0.0-1.0).
"""
weights_data = await subtensor.weights(netuid=netuid)
current_weights: dict[int, float] = {}

for validator_uid, weight_list in weights_data:
if validator_uid == uid:
for dest_netuid, raw_weight in weight_list:
current_weights[dest_netuid] = u16_normalized_float(raw_weight)
break

return current_weights


async def get_limits(subtensor: SubtensorInterface) -> tuple[int, float]:
# Get weight restrictions.
maw, mwl = await asyncio.gather(
Expand Down Expand Up @@ -459,17 +487,39 @@ async def _do_set_weights():

# Ask before moving on.
if prompt:
# Fetch current weights for comparison
print_verbose("Fetching current weights for comparison")
current_weights = await get_current_weights_for_uid(
subtensor, netuid=0, uid=my_uid
)

table = Table(
Column("[dark_orange]Netuid", justify="center", style="bold green"),
Column(
"[dark_orange]Weight", justify="center", style="bold light_goldenrod2"
),
Column("[dark_orange]Current", justify="center", style="dim"),
Column("[dark_orange]New", justify="center", style="bold light_goldenrod2"),
Column("[dark_orange]Change", justify="center"),
expand=False,
show_edge=False,
)

for netuid, weight in zip(netuids, formatted_weights):
table.add_row(str(netuid), f"{weight:.8f}")
for netuid, new_weight in zip(netuids, formatted_weights):
current_weight = current_weights.get(netuid, 0.0)
diff = new_weight - current_weight

# Format the difference with color and sign
if diff > 0.00000001:
diff_str = f"[green]+{diff:.8f}[/green]"
elif diff < -0.00000001:
diff_str = f"[red]{diff:.8f}[/red]"
else:
diff_str = "[dim]0.00000000[/dim]"

table.add_row(
str(netuid),
f"{current_weight:.8f}",
f"{new_weight:.8f}",
diff_str,
)

console.print(table)
if not Confirm.ask("\nDo you want to set these root weights?"):
Expand Down
133 changes: 133 additions & 0 deletions tests/unit_tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import numpy as np
import pytest
import typer

from bittensor_cli.cli import parse_mnemonic, CLIManager
from bittensor_cli.src.bittensor.extrinsics.root import (
get_current_weights_for_uid,
set_root_weights_extrinsic,
)
from unittest.mock import AsyncMock, patch, MagicMock, Mock


Expand Down Expand Up @@ -659,3 +664,131 @@ def test_stake_transfer_calls_proxy_validation():

# Assert that proxy validation was called
mock_proxy_validation.assert_called_once_with(valid_proxy, False)


# ============================================================================
# Tests for root weights difference display
# ============================================================================


@pytest.mark.asyncio
async def test_get_current_weights_for_uid_success():
"""Test fetching current weights for a specific UID."""
mock_subtensor = MagicMock()

# Mock weights data: [(uid, [(dest_netuid, raw_weight), ...]), ...]
mock_weights_data = [
(0, [(0, 32768), (1, 16384), (2, 16384)]),
(1, [(0, 65535), (1, 0), (2, 0)]),
]
mock_subtensor.weights = AsyncMock(return_value=mock_weights_data)

result = await get_current_weights_for_uid(mock_subtensor, netuid=0, uid=0)

mock_subtensor.weights.assert_called_once_with(netuid=0)
assert 0 in result
assert 1 in result
assert 2 in result
# 32768 / 65535 ≈ 0.5
assert abs(result[0] - 0.5) < 0.01


@pytest.mark.asyncio
async def test_get_current_weights_for_uid_not_found():
"""Test fetching weights for a UID that doesn't exist."""
mock_subtensor = MagicMock()
mock_weights_data = [
(0, [(0, 32768), (1, 16384)]),
(1, [(0, 65535)]),
]
mock_subtensor.weights = AsyncMock(return_value=mock_weights_data)

result = await get_current_weights_for_uid(mock_subtensor, netuid=0, uid=5)

assert result == {}


@pytest.mark.asyncio
async def test_get_current_weights_for_uid_empty():
"""Test fetching weights when the network has no weights set."""
mock_subtensor = MagicMock()
mock_subtensor.weights = AsyncMock(return_value=[])

result = await get_current_weights_for_uid(mock_subtensor, netuid=0, uid=0)

assert result == {}


@pytest.mark.asyncio
async def test_set_root_weights_fetches_current_weights_with_prompt():
"""Test that set_root_weights fetches current weights when prompt=True."""
mock_subtensor = MagicMock()
mock_wallet = MagicMock()
mock_subtensor.query = AsyncMock(return_value=0)

with (
patch("bittensor_cli.src.bittensor.extrinsics.root.unlock_key") as mock_unlock,
patch("bittensor_cli.src.bittensor.extrinsics.root.get_limits") as mock_limits,
patch(
"bittensor_cli.src.bittensor.extrinsics.root.get_current_weights_for_uid"
) as mock_get_current,
patch("bittensor_cli.src.bittensor.extrinsics.root.console"),
patch("bittensor_cli.src.bittensor.extrinsics.root.Confirm") as mock_confirm,
):
mock_unlock.return_value = MagicMock(success=True)
mock_limits.return_value = (1, 0.5)
mock_get_current.return_value = {0: 0.5, 1: 0.3, 2: 0.2}
mock_confirm.ask.return_value = False

netuids = np.array([0, 1, 2], dtype=np.int64)
weights = np.array([0.4, 0.3, 0.3], dtype=np.float32)

await set_root_weights_extrinsic(
subtensor=mock_subtensor,
wallet=mock_wallet,
netuids=netuids,
weights=weights,
prompt=True,
)

mock_get_current.assert_called_once_with(mock_subtensor, netuid=0, uid=0)


@pytest.mark.asyncio
async def test_set_root_weights_skips_current_weights_without_prompt():
"""Test that set_root_weights skips fetching current weights when prompt=False."""
mock_subtensor = MagicMock()
mock_wallet = MagicMock()
mock_subtensor.query = AsyncMock(return_value=0)
mock_subtensor.substrate = MagicMock()
mock_subtensor.substrate.compose_call = AsyncMock()
mock_subtensor.substrate.create_signed_extrinsic = AsyncMock()
mock_response = MagicMock()
mock_response.is_success = True
mock_subtensor.substrate.submit_extrinsic = AsyncMock(return_value=mock_response)

with (
patch("bittensor_cli.src.bittensor.extrinsics.root.unlock_key") as mock_unlock,
patch("bittensor_cli.src.bittensor.extrinsics.root.get_limits") as mock_limits,
patch(
"bittensor_cli.src.bittensor.extrinsics.root.get_current_weights_for_uid"
) as mock_get_current,
patch("bittensor_cli.src.bittensor.extrinsics.root.console"),
):
mock_unlock.return_value = MagicMock(success=True)
mock_limits.return_value = (1, 0.5)

netuids = np.array([0, 1, 2], dtype=np.int64)
weights = np.array([0.4, 0.3, 0.3], dtype=np.float32)

await set_root_weights_extrinsic(
subtensor=mock_subtensor,
wallet=mock_wallet,
netuids=netuids,
weights=weights,
prompt=False,
wait_for_inclusion=False,
wait_for_finalization=False,
)

mock_get_current.assert_not_called()