From 133485e00bf8baffb4ac5215ab590158253f9883 Mon Sep 17 00:00:00 2001 From: Benjamin Himes Date: Wed, 27 Nov 2024 18:14:31 +0200 Subject: [PATCH 1/3] Fixes the `AsyncSubstrateInterface._get_block_handler` method's `result_handler` and adds a `wait_for_block` method. --- bittensor/utils/async_substrate_interface.py | 87 +++++++++++++++----- 1 file changed, 66 insertions(+), 21 deletions(-) diff --git a/bittensor/utils/async_substrate_interface.py b/bittensor/utils/async_substrate_interface.py index 52e697b333..9e97793afd 100644 --- a/bittensor/utils/async_substrate_interface.py +++ b/bittensor/utils/async_substrate_interface.py @@ -5,6 +5,7 @@ """ import asyncio +import inspect import json import random from collections import defaultdict @@ -1171,14 +1172,14 @@ async def _get_block_handler( include_author: bool = False, header_only: bool = False, finalized_only: bool = False, - subscription_handler: Optional[Callable] = None, + subscription_handler: Optional[Callable[[dict], Awaitable[Any]]] = None, ): try: await self.init_runtime(block_hash=block_hash) except BlockNotFound: return None - async def decode_block(block_data, block_data_hash=None): + async def decode_block(block_data, block_data_hash=None) -> dict[str, Any]: if block_data: if block_data_hash: block_data["header"]["hash"] = block_data_hash @@ -1193,12 +1194,12 @@ async def decode_block(block_data, block_data_hash=None): if "extrinsics" in block_data: for idx, extrinsic_data in enumerate(block_data["extrinsics"]): - extrinsic_decoder = extrinsic_cls( - data=ScaleBytes(extrinsic_data), - metadata=self.__metadata, - runtime_config=self.runtime_config, - ) try: + extrinsic_decoder = extrinsic_cls( + data=ScaleBytes(extrinsic_data), + metadata=self.__metadata, + runtime_config=self.runtime_config, + ) extrinsic_decoder.decode(check_remaining=True) block_data["extrinsics"][idx] = extrinsic_decoder @@ -1314,23 +1315,29 @@ async def decode_block(block_data, block_data_hash=None): if callable(subscription_handler): rpc_method_prefix = "Finalized" if finalized_only else "New" - async def result_handler(message, update_nr, subscription_id): - new_block = await decode_block({"header": message["params"]["result"]}) + async def result_handler( + message: dict, subscription_id: str + ) -> tuple[Any, bool]: + reached = False + subscription_result = None + if "params" in message: + new_block = await decode_block( + {"header": message["params"]["result"]} + ) - subscription_result = subscription_handler( - new_block, update_nr, subscription_id - ) + subscription_result = await subscription_handler(new_block) - if subscription_result is not None: - # Handler returned end result: unsubscribe from further updates - self._forgettable_task = asyncio.create_task( - self.rpc_request( - f"chain_unsubscribe{rpc_method_prefix}Heads", - [subscription_id], + if subscription_result is not None: + reached = True + # Handler returned end result: unsubscribe from further updates + self._forgettable_task = asyncio.create_task( + self.rpc_request( + f"chain_unsubscribe{rpc_method_prefix}Heads", + [subscription_id], + ) ) - ) - return subscription_result + return subscription_result, reached result = await self._make_rpc_request( [ @@ -1343,7 +1350,7 @@ async def result_handler(message, update_nr, subscription_id): result_handler=result_handler, ) - return result + return result["_get_block_handler"][-1] else: if header_only: @@ -2770,3 +2777,41 @@ async def close(self): await self.ws.shutdown() except AttributeError: pass + + async def wait_for_block( + self, + block: int, + result_handler: Callable[[dict], Awaitable[Any]], + task_return: bool = True, + ) -> Union[asyncio.Task, Union[bool, Any]]: + """ + Executes the result_handler when the chain has reached the block specified. + + Args: + block: block number + result_handler: coroutine executed upon reaching the block number. This can be basically anything, but + must accept one single arg, a dict with the block data; whether you use this data or not is entirely + up to you. + task_return: True to immediately return the result of wait_for_block as an asyncio Task, False to wait + for the block to be reached, and return the result of the result handler. + """ + + async def _handler(block_data: dict[str, Any]): + required_number = block + number = block_data["header"]["number"] + if number >= required_number: + return await result_handler(block_data) or True + + args = inspect.getfullargspec(result_handler).args + if len(args) != 1: + raise ValueError( + "result_handler must take exactly one arg: the dict block data." + ) + + co = self._get_block_handler( + self.last_block_hash, subscription_handler=_handler + ) + if task_return is True: + return asyncio.create_task(co) + else: + return await co From 05f326e9536c60aeaae8c4c45cfdc12141b59d7a Mon Sep 17 00:00:00 2001 From: Benjamin Himes Date: Wed, 27 Nov 2024 18:26:00 +0200 Subject: [PATCH 2/3] Add unit tests --- .../utils/test_async_substrate_interface.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 tests/unit_tests/utils/test_async_substrate_interface.py diff --git a/tests/unit_tests/utils/test_async_substrate_interface.py b/tests/unit_tests/utils/test_async_substrate_interface.py new file mode 100644 index 0000000000..e7c77b9662 --- /dev/null +++ b/tests/unit_tests/utils/test_async_substrate_interface.py @@ -0,0 +1,38 @@ +import pytest +import asyncio +from bittensor.utils import async_substrate_interface +from typing import Any + + +@pytest.mark.asyncio +async def test_wait_for_block_invalid_result_handler(): + chain_interface = async_substrate_interface.AsyncSubstrateInterface( + "dummy_endpoint" + ) + + with pytest.raises(ValueError): + + async def dummy_handler( + block_data: dict[str, Any], extra_arg + ): # extra argument + return block_data.get("header", {}).get("number", -1) == 2 + + await chain_interface.wait_for_block( + block=2, result_handler=dummy_handler, task_return=False + ) + + +@pytest.mark.asyncio +async def test_wait_for_block_async_return(): + chain_interface = async_substrate_interface.AsyncSubstrateInterface( + "dummy_endpoint" + ) + + async def dummy_handler(block_data: dict[str, Any]) -> bool: + return block_data.get("header", {}).get("number", -1) == 2 + + result = await chain_interface.wait_for_block( + block=2, result_handler=dummy_handler, task_return=True + ) + + assert isinstance(result, asyncio.Task) From 34ddb95df7c136ceba0deff99f0e3dfaff201fd5 Mon Sep 17 00:00:00 2001 From: Benjamin Himes Date: Wed, 27 Nov 2024 18:46:50 +0200 Subject: [PATCH 3/3] Docstring update. --- bittensor/utils/async_substrate_interface.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/bittensor/utils/async_substrate_interface.py b/bittensor/utils/async_substrate_interface.py index 9e97793afd..cd50695492 100644 --- a/bittensor/utils/async_substrate_interface.py +++ b/bittensor/utils/async_substrate_interface.py @@ -2794,13 +2794,21 @@ async def wait_for_block( up to you. task_return: True to immediately return the result of wait_for_block as an asyncio Task, False to wait for the block to be reached, and return the result of the result handler. + + Returns: + Either an asyncio.Task (which contains the running subscription, and whose `result()` will contain the + return of the result_handler), or the result itself, depending on `task_return` flag. + Note that if your result_handler returns `None`, this method will return `True`, otherwise + the return will be the result of your result_handler. """ async def _handler(block_data: dict[str, Any]): required_number = block number = block_data["header"]["number"] if number >= required_number: - return await result_handler(block_data) or True + return ( + r if (r := await result_handler(block_data)) is not None else True + ) args = inspect.getfullargspec(result_handler).args if len(args) != 1: