From ecf6915e7523e0c83a815d09370d512dd3ea46d0 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sun, 3 Jul 2022 17:45:23 +0200 Subject: [PATCH] Adding explicit json_serialize argument in AIOHTTPTransport --- gql/transport/aiohttp.py | 15 +++++++++----- tests/test_aiohttp.py | 45 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 5 deletions(-) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 6d51f4f3..f4f38b69 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -4,7 +4,7 @@ import json import logging from ssl import SSLContext -from typing import Any, AsyncGenerator, Dict, Optional, Tuple, Type, Union +from typing import Any, AsyncGenerator, Callable, Dict, Optional, Tuple, Type, Union import aiohttp from aiohttp.client_exceptions import ClientResponseError @@ -49,6 +49,7 @@ def __init__( ssl: Union[SSLContext, bool, Fingerprint] = False, timeout: Optional[int] = None, ssl_close_timeout: Optional[Union[int, float]] = 10, + json_serialize: Callable = json.dumps, client_session_args: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the transport with the given aiohttp parameters. @@ -61,6 +62,8 @@ def __init__( :param ssl: ssl_context of the connection. Use ssl=False to disable encryption :param ssl_close_timeout: Timeout in seconds to wait for the ssl connection to close properly + :param json_serialize: Json serializer callable. + By default json.dumps() function :param client_session_args: Dict of extra args passed to `aiohttp.ClientSession`_ @@ -77,6 +80,7 @@ def __init__( self.client_session_args = client_session_args self.session: Optional[aiohttp.ClientSession] = None self.response_headers: Optional[CIMultiDictProxy[str]] + self.json_serialize: Callable = json_serialize async def connect(self) -> None: """Coroutine which will create an aiohttp ClientSession() as self.session. @@ -96,6 +100,7 @@ async def connect(self) -> None: "auth": None if isinstance(self.auth, AppSyncAuthentication) else self.auth, + "json_serialize": self.json_serialize, } if self.timeout is not None: @@ -248,14 +253,14 @@ async def execute( file_streams = {str(i): files[path] for i, path in enumerate(files)} # Add the payload to the operations field - operations_str = json.dumps(payload) + operations_str = self.json_serialize(payload) log.debug("operations %s", operations_str) data.add_field( "operations", operations_str, content_type="application/json" ) # Add the file map field - file_map_str = json.dumps(file_map) + file_map_str = self.json_serialize(file_map) log.debug("file_map %s", file_map_str) data.add_field("map", file_map_str, content_type="application/json") @@ -270,7 +275,7 @@ async def execute( payload["variables"] = variable_values if log.isEnabledFor(logging.INFO): - log.info(">>> %s", json.dumps(payload)) + log.info(">>> %s", self.json_serialize(payload)) post_args = {"json": payload} @@ -281,7 +286,7 @@ async def execute( # Add headers for AppSync if requested if isinstance(self.auth, AppSyncAuthentication): post_args["headers"] = self.auth.get_headers( - json.dumps(payload), + self.json_serialize(payload), {"content-type": "application/json"}, ) diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 4a70956c..3a84d21e 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1339,3 +1339,48 @@ async def handler(request): assert expected_warning in caplog.text await client.close_async() + + +@pytest.mark.asyncio +async def test_aiohttp_json_serializer(event_loop, aiohttp_server, caplog): + from aiohttp import web + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + + request_text = await request.text() + print("Received on backend: " + request_text) + + return web.Response( + text=query1_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport( + url=url, + timeout=10, + json_serialize=lambda e: json.dumps(e, separators=(",", ":")), + ) + + async with Client(transport=transport) as session: + + query = gql(query1_str) + + # Execute query asynchronously + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + # Checking that there is no space after the colon in the log + expected_log = '"query":"query getContinents' + assert expected_log in caplog.text