diff --git a/async_substrate_interface/async_substrate.py b/async_substrate_interface/async_substrate.py index efeb24b..7950c64 100644 --- a/async_substrate_interface/async_substrate.py +++ b/async_substrate_interface/async_substrate.py @@ -685,7 +685,7 @@ def __init__( self.ws = Websocket( url, options={ - "max_size": 2**32, + "max_size": self.ws_max_size, "write_limit": 2**16, }, ) diff --git a/async_substrate_interface/sync_substrate.py b/async_substrate_interface/sync_substrate.py index 0cb87d1..105ddde 100644 --- a/async_substrate_interface/sync_substrate.py +++ b/async_substrate_interface/sync_substrate.py @@ -15,6 +15,7 @@ ) from scalecodec.base import RuntimeConfigurationObject, ScaleBytes, ScaleType from websockets.sync.client import connect +from websockets.exceptions import ConnectionClosed from async_substrate_interface.errors import ( ExtrinsicNotFound, @@ -507,6 +508,7 @@ def __init__( self._metadata_cache = {} self.metadata_version_hex = "0x0f000000" # v15 self.reload_type_registry() + self.ws = self.connect(init=True) if not _mock: self.initialize() @@ -527,7 +529,7 @@ def initialize(self): self.initialized = True def __exit__(self, exc_type, exc_val, exc_tb): - pass + self.ws.close() @property def properties(self): @@ -562,6 +564,15 @@ def name(self): self._name = self.rpc_request("system_name", []).get("result") return self._name + def connect(self, init=False): + if init is True: + return connect(self.chain_endpoint, max_size=self.ws_max_size) + else: + if not self.ws.close_code: + return self.ws + else: + return connect(self.chain_endpoint, max_size=self.ws_max_size) + def get_storage_item(self, module: str, storage_function: str): if not self._metadata: self.init_runtime() @@ -1620,69 +1631,67 @@ def _make_rpc_request( _received = {} subscription_added = False - with connect(self.chain_endpoint, max_size=2**32) as ws: - item_id = 0 - for payload in payloads: - item_id += 1 - ws.send(json.dumps({**payload["payload"], **{"id": item_id}})) - request_manager.add_request(item_id, payload["id"]) - - while True: - try: - response = json.loads( - ws.recv(timeout=self.retry_timeout, decode=False) + ws = self.connect(init=False if attempt == 1 else True) + item_id = 0 + for payload in payloads: + item_id += 1 + ws.send(json.dumps({**payload["payload"], **{"id": item_id}})) + request_manager.add_request(item_id, payload["id"]) + + while True: + try: + response = json.loads(ws.recv(timeout=self.retry_timeout, decode=False)) + except (TimeoutError, ConnectionClosed): + if attempt >= self.max_retries: + logging.warning( + f"Timed out waiting for RPC requests {attempt} times. Exiting." ) - except TimeoutError: - if attempt >= self.max_retries: - logging.warning( - f"Timed out waiting for RPC requests {attempt} times. Exiting." - ) - raise SubstrateRequestException("Max retries reached.") - else: - return self._make_rpc_request( - payloads, + raise SubstrateRequestException("Max retries reached.") + else: + return self._make_rpc_request( + payloads, + value_scale_type, + storage_item, + result_handler, + attempt + 1, + ) + if "id" in response: + _received[response["id"]] = response + elif "params" in response: + _received[response["params"]["subscription"]] = response + else: + raise SubstrateRequestException(response) + for item_id in list(request_manager.response_map.keys()): + if item_id not in request_manager.responses or isinstance( + result_handler, Callable + ): + if response := _received.pop(item_id): + if ( + isinstance(result_handler, Callable) + and not subscription_added + ): + # handles subscriptions, overwrites the previous mapping of {item_id : payload_id} + # with {subscription_id : payload_id} + try: + item_id = request_manager.overwrite_request( + item_id, response["result"] + ) + subscription_added = True + except KeyError: + raise SubstrateRequestException(str(response)) + decoded_response, complete = self._process_response( + response, + item_id, value_scale_type, storage_item, result_handler, - attempt + 1, ) - if "id" in response: - _received[response["id"]] = response - elif "params" in response: - _received[response["params"]["subscription"]] = response - else: - raise SubstrateRequestException(response) - for item_id in list(request_manager.response_map.keys()): - if item_id not in request_manager.responses or isinstance( - result_handler, Callable - ): - if response := _received.pop(item_id): - if ( - isinstance(result_handler, Callable) - and not subscription_added - ): - # handles subscriptions, overwrites the previous mapping of {item_id : payload_id} - # with {subscription_id : payload_id} - try: - item_id = request_manager.overwrite_request( - item_id, response["result"] - ) - subscription_added = True - except KeyError: - raise SubstrateRequestException(str(response)) - decoded_response, complete = self._process_response( - response, - item_id, - value_scale_type, - storage_item, - result_handler, - ) - request_manager.add_response( - item_id, decoded_response, complete - ) + request_manager.add_response( + item_id, decoded_response, complete + ) - if request_manager.is_complete: - break + if request_manager.is_complete: + break return request_manager.get_results() @@ -2874,9 +2883,8 @@ def close(self): """ Closes the substrate connection, and the websocket connection. """ - # TODO change this logic try: - self.ws.shutdown() + self.ws.close() except AttributeError: pass diff --git a/async_substrate_interface/types.py b/async_substrate_interface/types.py index b4c38ed..daaaafc 100644 --- a/async_substrate_interface/types.py +++ b/async_substrate_interface/types.py @@ -344,6 +344,7 @@ class SubstrateMixin(ABC): runtime_config: RuntimeConfigurationObject type_registry: Optional[dict] ss58_format: Optional[int] + ws_max_size = 2**32 @property def chain(self):