diff --git a/gql/client.py b/gql/client.py index 91cbcde6..5203d17d 100644 --- a/gql/client.py +++ b/gql/client.py @@ -271,8 +271,15 @@ async def __aenter__(self): self.session = AsyncClientSession(client=self) # Get schema from transport if needed - if self.fetch_schema_from_transport and not self.schema: - await self.session.fetch_schema() + try: + if self.fetch_schema_from_transport and not self.schema: + await self.session.fetch_schema() + except Exception: + # we don't know what type of exception is thrown here because it + # depends on the underlying transport; we just make sure that the + # transport is closed and re-raise the exception + await self.transport.close() + raise return self.session @@ -293,8 +300,15 @@ def __enter__(self): self.session = SyncClientSession(client=self) # Get schema from transport if needed - if self.fetch_schema_from_transport and not self.schema: - self.session.fetch_schema() + try: + if self.fetch_schema_from_transport and not self.schema: + self.session.fetch_schema() + except Exception: + # we don't know what type of exception is thrown here because it + # depends on the underlying transport; we just make sure that the + # transport is closed and re-raise the exception + self.transport.close() + raise return self.session diff --git a/tests/test_client.py b/tests/test_client.py index c8df40ee..fecdf43d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -200,3 +200,50 @@ def test_gql(): client = Client(schema=schema) result = client.execute(query) assert result["user"] is None + + +@pytest.mark.requests +def test_sync_transport_close_on_schema_retrieval_failure(): + """ + Ensure that the transport session is closed if an error occurs when + entering the context manager (e.g., because schema retrieval fails) + """ + + from gql.transport.requests import RequestsHTTPTransport + + transport = RequestsHTTPTransport(url="http://localhost/") + client = Client(transport=transport, fetch_schema_from_transport=True) + + try: + with client: + pass + except Exception: + # we don't care what exception is thrown, we just want to check if the + # transport is closed afterwards + pass + + assert client.transport.session is None + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_async_transport_close_on_schema_retrieval_failure(): + """ + Ensure that the transport session is closed if an error occurs when + entering the context manager (e.g., because schema retrieval fails) + """ + + from gql.transport.aiohttp import AIOHTTPTransport + + transport = AIOHTTPTransport(url="http://localhost/") + client = Client(transport=transport, fetch_schema_from_transport=True) + + try: + async with client: + pass + except Exception: + # we don't care what exception is thrown, we just want to check if the + # transport is closed afterwards + pass + + assert client.transport.session is None