Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions bittensor/core/async_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3819,7 +3819,7 @@ async def get_stake_info_for_coldkey(
block: Optional[int] = None,
block_hash: Optional[str] = None,
reuse_block: bool = False,
) -> Optional[list["StakeInfo"]]:
) -> list["StakeInfo"]:
"""
Retrieves the stake information for a given coldkey.

Expand All @@ -3830,7 +3830,7 @@ async def get_stake_info_for_coldkey(
reuse_block: Whether to reuse the last-used block hash.

Returns:
An optional list of StakeInfo objects, or ``None`` if no stake information is found.
List of StakeInfo objects.
"""
result = await self.query_runtime_api(
runtime_api="StakeInfoRuntimeApi",
Expand All @@ -3847,6 +3847,42 @@ async def get_stake_info_for_coldkey(
stakes: list[StakeInfo] = StakeInfo.list_from_dicts(result)
return [stake for stake in stakes if stake.stake > 0]

async def get_stake_info_for_coldkeys(
self,
coldkey_ss58s: list[str],
block: Optional[int] = None,
block_hash: Optional[str] = None,
reuse_block: bool = False,
) -> dict[str, list["StakeInfo"]]:
"""
Retrieves the stake information for multiple coldkeys.

Parameters:
coldkey_ss58s: A list of SS58 addresses of the coldkeys to query.
block: The block number at which to query the stake information.
block_hash: The hash of the blockchain block number for the query.
reuse_block: Whether to reuse the last-used block hash.

Returns:
The dictionary mapping coldkey addresses to a list of StakeInfo objects.
"""
query = await self.query_runtime_api(
runtime_api="StakeInfoRuntimeApi",
method="get_stake_info_for_coldkeys",
params=[coldkey_ss58s],
block=block,
block_hash=block_hash,
reuse_block=reuse_block,
)

if query is None:
return {}

return {
decode_account_id(ck): StakeInfo.list_from_dicts(st_info)
for ck, st_info in query
}

async def get_stake_for_hotkey(
self,
hotkey_ss58: str,
Expand Down
33 changes: 30 additions & 3 deletions bittensor/core/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2905,7 +2905,7 @@ def get_stake_info_for_coldkey(
block: The block number at which to query the stake information.

Returns:
An optional list of StakeInfo objects, or ``None`` if no stake information is found.
List of StakeInfo objects.
"""
result = self.query_runtime_api(
runtime_api="StakeInfoRuntimeApi",
Expand All @@ -2916,8 +2916,35 @@ def get_stake_info_for_coldkey(

if result is None:
return []
stakes: list[StakeInfo] = StakeInfo.list_from_dicts(result)
return [stake for stake in stakes if stake.stake > 0]
return StakeInfo.list_from_dicts(result)

def get_stake_info_for_coldkeys(
self, coldkey_ss58s: list[str], block: Optional[int] = None
) -> dict[str, list["StakeInfo"]]:
"""
Retrieves the stake information for multiple coldkeys.

Parameters:
coldkey_ss58s: A list of SS58 addresses of the coldkeys to query.
block: The block number at which to query the stake information.

Returns:
The dictionary mapping coldkey addresses to a list of StakeInfo objects.
"""
query = self.query_runtime_api(
runtime_api="StakeInfoRuntimeApi",
method="get_stake_info_for_coldkeys",
params=[coldkey_ss58s],
block=block,
)

if query is None:
return {}

return {
decode_account_id(ck): StakeInfo.list_from_dicts(st_info)
for ck, st_info in query
}

def get_stake_for_hotkey(
self, hotkey_ss58: str, netuid: int, block: Optional[int] = None
Expand Down
1 change: 1 addition & 0 deletions bittensor/extras/subtensor_api/staking.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(self, subtensor: Union["_Subtensor", "_AsyncSubtensor"]):
subtensor.get_stake_for_coldkey_and_hotkey
)
self.get_stake_info_for_coldkey = subtensor.get_stake_info_for_coldkey
self.get_stake_info_for_coldkeys = subtensor.get_stake_info_for_coldkeys
self.get_stake_movement_fee = subtensor.get_stake_movement_fee
self.get_stake_weight = subtensor.get_stake_weight
self.get_unstake_fee = subtensor.get_unstake_fee
Expand Down
122 changes: 122 additions & 0 deletions tests/unit_tests/test_async_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5645,3 +5645,125 @@ async def test_blocks_until_next_epoch_uses_default_tempo(subtensor, mocker):
spy_tempo.assert_not_awaited()
assert result is not None
assert isinstance(result, int)


@pytest.mark.asyncio
async def test_get_stake_info_for_coldkeys_none(subtensor, mocker):
"""Tests get_stake_info_for_coldkeys method when query_runtime_api returns None."""
# Preps
fake_coldkey_ss58s = ["coldkey1", "coldkey2"]
fake_block = 123
fake_block_hash = None
fake_reuse_block = False

mocked_query_runtime_api = mocker.AsyncMock(
autospec=subtensor.query_runtime_api, return_value=None
)
subtensor.query_runtime_api = mocked_query_runtime_api

# Call
result = await subtensor.get_stake_info_for_coldkeys(
coldkey_ss58s=fake_coldkey_ss58s,
block=fake_block,
block_hash=fake_block_hash,
reuse_block=fake_reuse_block,
)

# Asserts
assert result == {}
mocked_query_runtime_api.assert_called_once_with(
runtime_api="StakeInfoRuntimeApi",
method="get_stake_info_for_coldkeys",
params=[fake_coldkey_ss58s],
block=fake_block,
block_hash=fake_block_hash,
reuse_block=fake_reuse_block,
)


@pytest.mark.asyncio
async def test_get_stake_info_for_coldkeys_success(subtensor, mocker):
"""Tests get_stake_info_for_coldkeys method when query_runtime_api returns data."""
# Preps
fake_coldkey_ss58s = ["coldkey1", "coldkey2"]
fake_block = 123
fake_block_hash = None
fake_reuse_block = False

fake_ck1 = b"\x16:\xech\r\xde,g\x03R1\xb9\x88q\xe79\xb8\x88\x93\xae\xd2)?*\rp\xb2\xe62\xads\x1c"
fake_ck2 = b"\x17:\xech\r\xde,g\x03R1\xb9\x88q\xe79\xb8\x88\x93\xae\xd2)?*\rp\xb2\xe62\xads\x1d"
fake_decoded_ck1 = "decoded_coldkey1"
fake_decoded_ck2 = "decoded_coldkey2"

stake_info_dict_1 = {
"netuid": 1,
"hotkey": b"\x16:\xech\r\xde,g\x03R1\xb9\x88q\xe79\xb8\x88\x93\xae\xd2)?*\rp\xb2\xe62\xads\x1c",
"coldkey": fake_ck1,
"stake": 1000,
"locked": 0,
"emission": 100,
"drain": 0,
"is_registered": True,
}
stake_info_dict_2 = {
"netuid": 2,
"hotkey": b"\x17:\xech\r\xde,g\x03R1\xb9\x88q\xe79\xb8\x88\x93\xae\xd2)?*\rp\xb2\xe62\xads\x1d",
"coldkey": fake_ck2,
"stake": 2000,
"locked": 0,
"emission": 200,
"drain": 0,
"is_registered": False,
}

fake_query_result = [
(fake_ck1, [stake_info_dict_1]),
(fake_ck2, [stake_info_dict_2]),
]

mocked_query_runtime_api = mocker.AsyncMock(
autospec=subtensor.query_runtime_api, return_value=fake_query_result
)
subtensor.query_runtime_api = mocked_query_runtime_api

mocked_decode_account_id = mocker.patch.object(
async_subtensor,
"decode_account_id",
side_effect=[fake_decoded_ck1, fake_decoded_ck2],
)

mock_stake_info_1 = mocker.Mock(spec=StakeInfo)
mock_stake_info_2 = mocker.Mock(spec=StakeInfo)
mocked_stake_info_list_from_dicts = mocker.patch.object(
async_subtensor.StakeInfo,
"list_from_dicts",
side_effect=[[mock_stake_info_1], [mock_stake_info_2]],
)

# Call
result = await subtensor.get_stake_info_for_coldkeys(
coldkey_ss58s=fake_coldkey_ss58s,
block=fake_block,
block_hash=fake_block_hash,
reuse_block=fake_reuse_block,
)

# Asserts
assert result == {
fake_decoded_ck1: [mock_stake_info_1],
fake_decoded_ck2: [mock_stake_info_2],
}
mocked_query_runtime_api.assert_called_once_with(
runtime_api="StakeInfoRuntimeApi",
method="get_stake_info_for_coldkeys",
params=[fake_coldkey_ss58s],
block=fake_block,
block_hash=fake_block_hash,
reuse_block=fake_reuse_block,
)
mocked_decode_account_id.assert_has_calls(
[mocker.call(fake_ck1), mocker.call(fake_ck2)]
)
mocked_stake_info_list_from_dicts.assert_has_calls(
[mocker.call([stake_info_dict_1]), mocker.call([stake_info_dict_2])]
)
106 changes: 105 additions & 1 deletion tests/unit_tests/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4210,7 +4210,7 @@ def test_get_stake_weight(subtensor, mocker):
result = subtensor.get_stake_weight(netuid=netuid)

# Asserts
mock_determine_block_hash.assert_called_once_with(block=None)
mock_determine_block_hash.assert_called_once()
mocked_query.assert_called_once_with(
module="SubtensorModule",
storage_function="StakeWeight",
Expand Down Expand Up @@ -5762,3 +5762,107 @@ def test_blocks_until_next_epoch_uses_default_tempo(subtensor, mocker):
spy_tempo.assert_not_called()
assert result is not None
assert isinstance(result, int)


def test_get_stake_info_for_coldkeys_none(subtensor, mocker):
"""Tests get_stake_info_for_coldkeys method when query_runtime_api returns None."""
# Preps
fake_coldkey_ss58s = ["coldkey1", "coldkey2"]
fake_block = 123

mocked_query_runtime_api = mocker.patch.object(
subtensor, "query_runtime_api", return_value=None
)

# Call
result = subtensor.get_stake_info_for_coldkeys(
coldkey_ss58s=fake_coldkey_ss58s, block=fake_block
)

# Asserts
assert result == {}
mocked_query_runtime_api.assert_called_once_with(
runtime_api="StakeInfoRuntimeApi",
method="get_stake_info_for_coldkeys",
params=[fake_coldkey_ss58s],
block=fake_block,
)


def test_get_stake_info_for_coldkeys_success(subtensor, mocker):
"""Tests get_stake_info_for_coldkeys method when query_runtime_api returns data."""
# Preps
fake_coldkey_ss58s = ["coldkey1", "coldkey2"]
fake_block = 123

fake_ck1 = b"\x16:\xech\r\xde,g\x03R1\xb9\x88q\xe79\xb8\x88\x93\xae\xd2)?*\rp\xb2\xe62\xads\x1c"
fake_ck2 = b"\x17:\xech\r\xde,g\x03R1\xb9\x88q\xe79\xb8\x88\x93\xae\xd2)?*\rp\xb2\xe62\xads\x1d"
fake_decoded_ck1 = "decoded_coldkey1"
fake_decoded_ck2 = "decoded_coldkey2"

stake_info_dict_1 = {
"netuid": 5,
"hotkey": b"\x16:\xech\r\xde,g\x03R1\xb9\x88q\xe79\xb8\x88\x93\xae\xd2)?*\rp\xb2\xe62\xads\x1c",
"coldkey": fake_ck1,
"stake": 1000,
"locked": 0,
"emission": 100,
"drain": 0,
"is_registered": True,
}
stake_info_dict_2 = {
"netuid": 14,
"hotkey": b"\x17:\xech\r\xde,g\x03R1\xb9\x88q\xe79\xb8\x88\x93\xae\xd2)?*\rp\xb2\xe62\xads\x1d",
"coldkey": fake_ck2,
"stake": 2000,
"locked": 0,
"emission": 200,
"drain": 0,
"is_registered": False,
}

fake_query_result = [
(fake_ck1, [stake_info_dict_1]),
(fake_ck2, [stake_info_dict_2]),
]

mocked_query_runtime_api = mocker.patch.object(
subtensor, "query_runtime_api", return_value=fake_query_result
)

mocked_decode_account_id = mocker.patch.object(
subtensor_module,
"decode_account_id",
side_effect=[fake_decoded_ck1, fake_decoded_ck2],
)

mock_stake_info_1 = mocker.Mock(spec=StakeInfo)
mock_stake_info_2 = mocker.Mock(spec=StakeInfo)
mocked_stake_info_list_from_dicts = mocker.patch.object(
subtensor_module.StakeInfo,
"list_from_dicts",
side_effect=[[mock_stake_info_1], [mock_stake_info_2]],
)

# Call
result = subtensor.get_stake_info_for_coldkeys(
coldkey_ss58s=fake_coldkey_ss58s, block=fake_block
)

# Asserts
assert result == {
fake_decoded_ck1: [mock_stake_info_1],
fake_decoded_ck2: [mock_stake_info_2],
}
mocked_query_runtime_api.assert_called_once_with(
runtime_api="StakeInfoRuntimeApi",
method="get_stake_info_for_coldkeys",
params=[fake_coldkey_ss58s],
block=fake_block,
)
mocked_decode_account_id.assert_has_calls(
[mocker.call(fake_ck1), mocker.call(fake_ck2)]
)
mocked_stake_info_list_from_dicts.assert_has_calls(
[mocker.call([stake_info_dict_1]), mocker.call([stake_info_dict_2])]
)