From 0fd8c4c7ffc78e29b49301c2972d675dfaa2e5b2 Mon Sep 17 00:00:00 2001 From: Benjamin Himes Date: Fri, 22 Nov 2024 16:26:45 +0200 Subject: [PATCH] Brings this implementation of async_substrate_interface.py up to date with the btcli version. --- bittensor/utils/async_substrate_interface.py | 68 +++++++++++--------- 1 file changed, 39 insertions(+), 29 deletions(-) diff --git a/bittensor/utils/async_substrate_interface.py b/bittensor/utils/async_substrate_interface.py index 982e8dfc96..ae4e44529b 100644 --- a/bittensor/utils/async_substrate_interface.py +++ b/bittensor/utils/async_substrate_interface.py @@ -22,6 +22,8 @@ from websockets.asyncio.client import connect from websockets.exceptions import ConnectionClosed +from bittensor.utils import hex_to_bytes + if TYPE_CHECKING: from websockets.asyncio.client import ClientConnection @@ -464,6 +466,9 @@ def __init__(self, chain, runtime_config, metadata, type_registry): self.runtime_config = runtime_config self.metadata = metadata + def __str__(self): + return f"Runtime: {self.chain} | {self.config}" + @property def implements_scaleinfo(self) -> bool: """ @@ -647,15 +652,12 @@ async def __aenter__(self): self._exit_task.cancel() if not self._initialized: self._initialized = True - await self._connect() + self.ws = await asyncio.wait_for( + connect(self.ws_url, **self._options), timeout=10 + ) self._receiving_task = asyncio.create_task(self._start_receiving()) return self - async def _connect(self): - self.ws = await asyncio.wait_for( - connect(self.ws_url, **self._options), timeout=10 - ) - async def __aexit__(self, exc_type, exc_val, exc_tb): async with self._lock: self._in_use -= 1 @@ -696,7 +698,7 @@ async def shutdown(self): async def _recv(self) -> None: try: - response = json.loads(await cast(ClientConnection, self.ws).recv()) + response = json.loads(await self.ws.recv()) async with self._lock: self._open_subscriptions -= 1 if "id" in response: @@ -770,11 +772,13 @@ def __init__( """ self.chain_endpoint = chain_endpoint self.__chain = chain_name - options = { - "max_size": 2**32, - "write_limit": 2**16, - } - self.ws = Websocket(chain_endpoint, options=options) + self.ws = Websocket( + chain_endpoint, + options={ + "max_size": 2**32, + "write_limit": 2**16, + }, + ) self._lock = asyncio.Lock() self.last_block_hash: Optional[str] = None self.config = { @@ -896,9 +900,10 @@ async def init_runtime( async def get_runtime(block_hash, block_id) -> Runtime: # Check if runtime state already set to current block - if (block_hash and block_hash == self.last_block_hash) or ( - block_id and block_id == self.block_id - ): + if ( + (block_hash and block_hash == self.last_block_hash) + or (block_id and block_id == self.block_id) + ) and self.metadata is not None: return Runtime( self.chain, self.runtime_config, @@ -944,9 +949,11 @@ async def get_runtime(block_hash, block_id) -> Runtime: raise SubstrateRequestException( f"No runtime information for block '{block_hash}'" ) - # Check if runtime state already set to current block - if runtime_info.get("specVersion") == self.runtime_version: + if ( + runtime_info.get("specVersion") == self.runtime_version + and self.metadata is not None + ): return Runtime( self.chain, self.runtime_config, @@ -961,16 +968,19 @@ async def get_runtime(block_hash, block_id) -> Runtime: if self.runtime_version in self.__metadata_cache: # Get metadata from cache # self.debug_message('Retrieved metadata for {} from memory'.format(self.runtime_version)) - self.metadata = self.__metadata_cache[self.runtime_version] + metadata = self.metadata = self.__metadata_cache[ + self.runtime_version + ] else: - self.metadata = await self.get_block_metadata( + metadata = self.metadata = await self.get_block_metadata( block_hash=runtime_block_hash, decode=True ) # self.debug_message('Retrieved metadata for {} from Substrate node'.format(self.runtime_version)) # Update metadata cache self.__metadata_cache[self.runtime_version] = self.metadata - + else: + metadata = self.metadata # Update type registry self.reload_type_registry(use_remote_preset=False, auto_discover=True) @@ -1011,7 +1021,10 @@ async def get_runtime(block_hash, block_id) -> Runtime: if block_id and block_hash: raise ValueError("Cannot provide block_hash and block_id at the same time") - if not (runtime := self.runtime_cache.retrieve(block_id, block_hash)): + if ( + not (runtime := self.runtime_cache.retrieve(block_id, block_hash)) + or runtime.metadata is None + ): runtime = await get_runtime(block_hash, block_id) self.runtime_cache.add_item(block_id, block_hash, runtime) return runtime @@ -2271,7 +2284,7 @@ async def get_metadata_constant(self, module_name, constant_name, block_hash=Non MetadataModuleConstants """ - # await self.init_runtime(block_hash=block_hash) + await self.init_runtime(block_hash=block_hash) for module in self.metadata.pallets: if module_name == module.name and module.constants: @@ -2285,7 +2298,7 @@ async def get_constant( constant_name: str, block_hash: Optional[str] = None, reuse_block_hash: bool = False, - ) -> "ScaleType": + ) -> Optional["ScaleType"]: """ Returns the decoded `ScaleType` object of the constant for given module name, call function name and block_hash (or chaintip if block_hash is omitted) @@ -2364,7 +2377,7 @@ async def query( raw_storage_key: Optional[bytes] = None, subscription_handler=None, reuse_block_hash: bool = False, - ) -> Union["ScaleType"]: + ) -> "ScaleType": """ Queries subtensor. This should only be used when making a single request. For multiple requests, you should use ``self.query_multiple`` @@ -2551,10 +2564,7 @@ def concat_hash_len(key_hasher: str) -> int: item_key = None try: - try: - item_bytes = bytes.fromhex(item[1][2:]) - except ValueError: - item_bytes = bytes.fromhex(item[1]) + item_bytes = hex_to_bytes(item[1]) item_value = await self.decode_scale( type_string=value_type, @@ -2720,7 +2730,7 @@ async def get_metadata_call_function( return call return None - async def get_block_number(self, block_hash: Optional[str] = None) -> int: + async def get_block_number(self, block_hash: Optional[str]) -> int: """Async version of `substrateinterface.base.get_block_number` method.""" response = await self.rpc_request("chain_getHeader", [block_hash])