diff --git a/gql/client.py b/gql/client.py index a79d4b72..0d9e36c7 100644 --- a/gql/client.py +++ b/gql/client.py @@ -106,7 +106,7 @@ def __init__( :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. Default: False. :param parse_results: Whether gql will try to parse the serialized output - sent by the backend. Can be used to unserialize custom scalars or enums. + sent by the backend. Can be used to deserialize custom scalars or enums. :param batch_interval: Time to wait in seconds for batching requests together. Batching is disabled (by default) if 0. :param batch_max: Maximum number of requests in a single batch. @@ -892,7 +892,7 @@ def _execute( :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. - :param parse_result: Whether gql will unserialize the result. + :param parse_result: Whether gql will deserialize the result. By default use the parse_results argument of the client. The extra arguments are passed to the transport execute method.""" @@ -1006,7 +1006,7 @@ def execute( :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. - :param parse_result: Whether gql will unserialize the result. + :param parse_result: Whether gql will deserialize the result. By default use the parse_results argument of the client. :param get_execution_result: return the full ExecutionResult instance instead of only the "data" field. Necessary if you want to get the "extensions" field. @@ -1057,7 +1057,7 @@ def _execute_batch( :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. - :param parse_result: Whether gql will unserialize the result. + :param parse_result: Whether gql will deserialize the result. By default use the parse_results argument of the client. :param validate_document: Whether we still need to validate the document. @@ -1151,7 +1151,7 @@ def execute_batch( :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. - :param parse_result: Whether gql will unserialize the result. + :param parse_result: Whether gql will deserialize the result. By default use the parse_results argument of the client. :param get_execution_result: return the full ExecutionResult instance instead of only the "data" field. Necessary if you want to get the "extensions" field. @@ -1333,7 +1333,7 @@ async def _subscribe( :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. - :param parse_result: Whether gql will unserialize the result. + :param parse_result: Whether gql will deserialize the result. By default use the parse_results argument of the client. The extra arguments are passed to the transport subscribe method.""" @@ -1454,7 +1454,7 @@ async def subscribe( :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. - :param parse_result: Whether gql will unserialize the result. + :param parse_result: Whether gql will deserialize the result. By default use the parse_results argument of the client. :param get_execution_result: yield the full ExecutionResult instance instead of only the "data" field. Necessary if you want to get the "extensions" field. @@ -1511,7 +1511,7 @@ async def _execute( :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. - :param parse_result: Whether gql will unserialize the result. + :param parse_result: Whether gql will deserialize the result. By default use the parse_results argument of the client. The extra arguments are passed to the transport execute method.""" @@ -1617,7 +1617,7 @@ async def execute( :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. - :param parse_result: Whether gql will unserialize the result. + :param parse_result: Whether gql will deserialize the result. By default use the parse_results argument of the client. :param get_execution_result: return the full ExecutionResult instance instead of only the "data" field. Necessary if you want to get the "extensions" field. diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 60f42c94..be22ce9c 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -50,6 +50,7 @@ def __init__( timeout: Optional[int] = None, ssl_close_timeout: Optional[Union[int, float]] = 10, json_serialize: Callable = json.dumps, + json_deserialize: Callable = json.loads, client_session_args: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the transport with the given aiohttp parameters. @@ -64,6 +65,8 @@ def __init__( to close properly :param json_serialize: Json serializer callable. By default json.dumps() function + :param json_deserialize: Json deserializer callable. + By default json.loads() function :param client_session_args: Dict of extra args passed to `aiohttp.ClientSession`_ @@ -81,6 +84,7 @@ def __init__( self.session: Optional[aiohttp.ClientSession] = None self.response_headers: Optional[CIMultiDictProxy[str]] self.json_serialize: Callable = json_serialize + self.json_deserialize: Callable = json_deserialize async def connect(self) -> None: """Coroutine which will create an aiohttp ClientSession() as self.session. @@ -328,7 +332,7 @@ async def raise_response_error(resp: aiohttp.ClientResponse, reason: str): ) try: - result = await resp.json(content_type=None) + result = await resp.json(loads=self.json_deserialize, content_type=None) if log.isEnabledFor(logging.INFO): result_text = await resp.text() diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py index cfc25dc9..811601b8 100644 --- a/gql/transport/httpx.py +++ b/gql/transport/httpx.py @@ -38,6 +38,7 @@ def __init__( self, url: Union[str, httpx.URL], json_serialize: Callable = json.dumps, + json_deserialize: Callable = json.loads, **kwargs, ): """Initialize the transport with the given httpx parameters. @@ -45,10 +46,13 @@ def __init__( :param url: The GraphQL server URL. Example: 'https://server.com:PORT/path'. :param json_serialize: Json serializer callable. By default json.dumps() function. + :param json_deserialize: Json deserializer callable. + By default json.loads() function. :param kwargs: Extra args passed to the `httpx` client. """ self.url = url self.json_serialize = json_serialize + self.json_deserialize = json_deserialize self.kwargs = kwargs def _prepare_request( @@ -145,7 +149,7 @@ def _prepare_result(self, response: httpx.Response) -> ExecutionResult: log.debug("<<< %s", response.text) try: - result: Dict[str, Any] = response.json() + result: Dict[str, Any] = self.json_deserialize(response.content) except Exception: self._raise_response_error(response, "Not a JSON answer") diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 09259e51..b16964d0 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1511,6 +1511,56 @@ async def handler(request): assert expected_log in caplog.text +query_float_str = """ + query getPi { + pi + } +""" + +query_float_server_answer_data = '{"pi": 3.141592653589793238462643383279502884197}' + +query_float_server_answer = f'{{"data":{query_float_server_answer_data}}}' + + +@pytest.mark.asyncio +async def test_aiohttp_json_deserializer(event_loop, aiohttp_server): + from aiohttp import web + from decimal import Decimal + from functools import partial + from gql.transport.aiohttp import AIOHTTPTransport + + async def handler(request): + return web.Response( + text=query_float_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + json_loads = partial(json.loads, parse_float=Decimal) + + transport = AIOHTTPTransport( + url=url, + timeout=10, + json_deserialize=json_loads, + ) + + async with Client(transport=transport) as session: + + query = gql(query_float_str) + + # Execute query asynchronously + result = await session.execute(query) + + pi = result["pi"] + + assert pi == Decimal("3.141592653589793238462643383279502884197") + + @pytest.mark.asyncio async def test_aiohttp_connector_owner_false(event_loop, aiohttp_server): from aiohttp import web, TCPConnector diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py index e5be73ec..3665f5d8 100644 --- a/tests/test_httpx_async.py +++ b/tests/test_httpx_async.py @@ -1389,3 +1389,54 @@ async def handler(request): # Checking that there is no space after the colon in the log expected_log = '"query":"query getContinents' assert expected_log in caplog.text + + +query_float_str = """ + query getPi { + pi + } +""" + +query_float_server_answer_data = '{"pi": 3.141592653589793238462643383279502884197}' + +query_float_server_answer = f'{{"data":{query_float_server_answer_data}}}' + + +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_json_deserializer(event_loop, aiohttp_server): + from aiohttp import web + from decimal import Decimal + from functools import partial + from gql.transport.httpx import HTTPXAsyncTransport + + async def handler(request): + return web.Response( + text=query_float_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + json_loads = partial(json.loads, parse_float=Decimal) + + transport = HTTPXAsyncTransport( + url=url, + timeout=10, + json_deserialize=json_loads, + ) + + async with Client(transport=transport) as session: + + query = gql(query_float_str) + + # Execute query asynchronously + result = await session.execute(query) + + pi = result["pi"] + + assert pi == Decimal("3.141592653589793238462643383279502884197")